Commit 39a6d0e6 authored by Christina Floristean

Merging in main branch

parents d8ee9c5f 84659c93
......@@ -17,9 +17,11 @@ from functools import partial
from typing import Dict, Text, Tuple
import torch
import jax.numpy as jnp
from openfold.np import residue_constants as rc
from openfold.utils import geometry, tensor_utils
from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
import numpy as np
......
......@@ -6,8 +6,8 @@ class EarlyStoppingVerbose(EarlyStopping):
The default EarlyStopping callback's verbose mode is too verbose.
This class outputs a message only when it's getting ready to stop.
"""
    def _evaluate_stopping_criteria(self, *args):
        should_stop, reason = super()._evaluate_stopping_criteria(*args)
    def _evaluate_stopping_criteria(self, *args, **kwargs):
        should_stop, reason = super()._evaluate_stopping_criteria(*args, **kwargs)
if(should_stop):
rank_zero_info(f"{reason}\n")
......
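For context, a minimal usage sketch, assuming the class lives in openfold.utils.callbacks and forwards EarlyStopping's constructor arguments (the monitor key and patience below are illustrative, not part of this commit):

from pytorch_lightning import Trainer
from openfold.utils.callbacks import EarlyStoppingVerbose

# Behaves like EarlyStopping, but only logs once stopping is imminent.
early_stopping = EarlyStoppingVerbose(
    monitor="val_loss",  # hypothetical metric name
    patience=5,
    mode="min",
)
trainer = Trainer(callbacks=[early_stopping])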
......@@ -11,11 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from typing import Any, Tuple, List, Callable, Optional
deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
if(deepspeed_is_installed):
import deepspeed
import deepspeed
import torch
import torch.utils.checkpoint
from typing import Any, Tuple, List, Callable, Optional
BLOCK_ARG = Any
......@@ -23,7 +27,11 @@ BLOCK_ARGS = List[BLOCK_ARG]
def get_checkpoint_fn():
if(deepspeed.checkpointing.is_configured()):
deepspeed_is_configured = (
deepspeed_is_installed and
deepspeed.checkpointing.is_configured()
)
if(deepspeed_is_configured):
checkpoint = deepspeed.checkpointing.checkpoint
else:
checkpoint = torch.utils.checkpoint.checkpoint
......@@ -73,7 +81,7 @@ def checkpoint_blocks(
# Avoids mishaps when the blocks take just one argument
args = wrap(args)
if blocks_per_ckpt is None:
if blocks_per_ckpt is None or not torch.is_grad_enabled():
return exec(blocks, args)
elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks):
raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)")
......
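A minimal sketch of how checkpoint_blocks is typically invoked (the block modules and shapes are illustrative, assuming the helpers live in openfold.utils.checkpointing). With blocks_per_ckpt set, activations are stored only at every n-th block boundary and recomputed on the backward pass; the new torch.is_grad_enabled() guard additionally skips checkpointing during inference:

import torch
from openfold.utils.checkpointing import checkpoint_blocks

blocks = [torch.nn.Linear(8, 8) for _ in range(4)]
x = torch.randn(2, 8, requires_grad=True)

# Checkpoint every 2 blocks; under torch.no_grad() this now falls
# through to a plain forward pass instead of checkpointing.
out, = checkpoint_blocks(blocks, (x,), blocks_per_ckpt=2)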
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import logging
import math
from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional
import torch
from openfold.utils.tensor_utils import (
tree_map,
tensor_tree_map,
)
def _fetch_dims(tree):
shapes = []
tree_type = type(tree)
if tree_type is dict:
for v in tree.values():
shapes.extend(_fetch_dims(v))
elif tree_type is list or tree_type is tuple:
for t in tree:
shapes.extend(_fetch_dims(t))
elif tree_type is torch.Tensor:
shapes.append(tree.shape)
else:
raise ValueError("Not supported")
return shapes
@torch.jit.ignore
def _flat_idx_to_idx(
flat_idx: int,
dims: Tuple[int],
) -> Tuple[int]:
idx = []
for d in reversed(dims):
idx.append(flat_idx % d)
flat_idx = flat_idx // d
return tuple(reversed(idx))
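A small worked example of the unraveling (same convention as numpy.unravel_index): with dims (2, 3, 5), flat index 22 decomposes as 22 = 1 * (3 * 5) + 1 * 5 + 2, so:

assert _flat_idx_to_idx(22, (2, 3, 5)) == (1, 1, 2)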
@torch.jit.ignore
def _get_minimal_slice_set(
    start: Sequence[int],
    end: Sequence[int],
    dims: Sequence[int],
    start_edges: Optional[Sequence[bool]] = None,
    end_edges: Optional[Sequence[bool]] = None,
) -> Sequence[Tuple[slice, ...]]:
"""
Produces an ordered sequence of tensor slices that, when used in
sequence on a tensor with shape dims, yields tensors that contain every
leaf in the contiguous range [start, end]. Care is taken to yield a
short sequence of slices, and perhaps even the shortest possible (I'm
pretty sure it's the latter).
end is INCLUSIVE.
"""
# start_edges and end_edges both indicate whether, starting from any given
# dimension, the start/end index is at the top/bottom edge of the
# corresponding tensor, modeled as a tree
def reduce_edge_list(l):
tally = 1
for i in range(len(l)):
reversed_idx = -1 * (i + 1)
l[reversed_idx] *= tally
tally = l[reversed_idx]
if(start_edges is None):
start_edges = [s == 0 for s in start]
reduce_edge_list(start_edges)
if(end_edges is None):
end_edges = [e == (d - 1) for e,d in zip(end, dims)]
reduce_edge_list(end_edges)
# Base cases. Either start/end are empty and we're done, or the final,
# one-dimensional tensor can be simply sliced
if(len(start) == 0):
return [tuple()]
elif(len(start) == 1):
return [(slice(start[0], end[0] + 1),)]
slices = []
path = []
# Dimensions common to start and end can be selected directly
for s,e in zip(start, end):
if(s == e):
path.append(slice(s, s + 1))
else:
break
path = tuple(path)
divergence_idx = len(path)
# start == end, and we're done
if(divergence_idx == len(dims)):
return [tuple(path)]
def upper():
sdi = start[divergence_idx]
return [
path + (slice(sdi, sdi + 1),) + s for s in
_get_minimal_slice_set(
start[divergence_idx + 1:],
[d - 1 for d in dims[divergence_idx + 1:]],
dims[divergence_idx + 1:],
start_edges=start_edges[divergence_idx + 1:],
end_edges=[1 for _ in end_edges[divergence_idx + 1:]]
)
]
def lower():
edi = end[divergence_idx]
return [
path + (slice(edi, edi + 1),) + s for s in
_get_minimal_slice_set(
[0 for _ in start[divergence_idx + 1:]],
end[divergence_idx + 1:],
dims[divergence_idx + 1:],
start_edges=[1 for _ in start_edges[divergence_idx + 1:]],
end_edges=end_edges[divergence_idx + 1:],
)
]
# If both start and end are at the edges of the subtree rooted at
# divergence_idx, we can just select the whole subtree at once
if(start_edges[divergence_idx] and end_edges[divergence_idx]):
slices.append(
path + (slice(start[divergence_idx], end[divergence_idx] + 1),)
)
# If just start is at the edge, we can grab almost all of the subtree,
# treating only the ragged bottom edge as an edge case
elif(start_edges[divergence_idx]):
slices.append(
path + (slice(start[divergence_idx], end[divergence_idx]),)
)
slices.extend(lower())
# Analogous to the previous case, but the top is ragged this time
elif(end_edges[divergence_idx]):
slices.extend(upper())
slices.append(
path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),)
)
# If both sides of the range are ragged, we need to handle both sides
# separately. If there's contiguous meat in between them, we can index it
# in one big chunk
else:
slices.extend(upper())
middle_ground = end[divergence_idx] - start[divergence_idx]
if(middle_ground > 1):
slices.append(
path + (slice(start[divergence_idx] + 1, end[divergence_idx]),)
)
slices.extend(lower())
return [tuple(s) for s in slices]
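An illustrative call: for a (3, 4) tensor, the inclusive flat range from index (0, 2) to (1, 3) covers two rows unevenly, and the helper expresses it as two slices rather than six single-element indices:

slices = _get_minimal_slice_set([0, 2], [1, 3], (3, 4))
assert slices == [
    (slice(0, 1), slice(2, 4)),  # tail of row 0
    (slice(1, 2), slice(0, 4)),  # all of row 1
]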
@torch.jit.ignore
def _chunk_slice(
t: torch.Tensor,
flat_start: int,
flat_end: int,
no_batch_dims: int,
) -> torch.Tensor:
"""
Equivalent to
t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]
but without the need for the initial reshape call, which can be
memory-intensive in certain situations. The only reshape operations
in this function are performed on sub-tensors that scale with
(flat_end - flat_start), the chunk size.
"""
batch_dims = t.shape[:no_batch_dims]
start_idx = list(_flat_idx_to_idx(flat_start, batch_dims))
# _get_minimal_slice_set is inclusive
end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims))
# Get an ordered list of slices to perform
slices = _get_minimal_slice_set(
start_idx,
end_idx,
batch_dims,
)
sliced_tensors = [t[s] for s in slices]
return torch.cat(
[s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors]
)
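A quick equivalence check against the reshape-based formulation (shapes illustrative):

import torch

t = torch.arange(2 * 3 * 4).view(2, 3, 4)
chunk = _chunk_slice(t, flat_start=2, flat_end=5, no_batch_dims=2)
assert torch.equal(chunk, t.reshape(-1, 4)[2:5])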
def chunk_layer(
layer: Callable,
inputs: Dict[str, Any],
chunk_size: int,
no_batch_dims: int,
low_mem: bool = False,
_out: Any = None,
_add_into_out: bool = False,
) -> Any:
"""
Implements the "chunking" procedure described in section 1.11.8.
Layer outputs and inputs are assumed to be simple "pytrees,"
consisting only of (arbitrarily nested) lists, tuples, and dicts with
torch.Tensor leaves.
Args:
layer:
The layer to be applied chunk-wise
inputs:
A (non-nested) dictionary of keyworded inputs. All leaves must
be tensors and must share the same batch dimensions.
chunk_size:
The number of sub-batches per chunk. If multiple batch
dimensions are specified, a "sub-batch" is defined as a single
indexing of all batch dimensions simultaneously (s.t. the
number of sub-batches is the product of the batch dimensions).
no_batch_dims:
How many of the initial dimensions of each input tensor can
be considered batch dimensions.
low_mem:
Avoids flattening potentially large input tensors. Unnecessary
in most cases, and is ever so slightly slower than the default
setting.
Returns:
The reassembled output of the layer on the inputs.
"""
if not (len(inputs) > 0):
raise ValueError("Must provide at least one input")
initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
def _prep_inputs(t):
if(not low_mem):
if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
t = t.reshape(-1, *t.shape[no_batch_dims:])
else:
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
return t
prepped_inputs = tensor_tree_map(_prep_inputs, inputs)
prepped_outputs = None
if(_out is not None):
reshape_fn = lambda t: t.view([-1] + list(t.shape[no_batch_dims:]))
prepped_outputs = tensor_tree_map(reshape_fn, _out)
flat_batch_dim = 1
for d in orig_batch_dims:
flat_batch_dim *= d
no_chunks = flat_batch_dim // chunk_size + (
flat_batch_dim % chunk_size != 0
)
i = 0
out = prepped_outputs
for _ in range(no_chunks):
# Chunk the input
if(not low_mem):
select_chunk = (
lambda t: t[i : i + chunk_size] if t.shape[0] != 1 else t
)
else:
select_chunk = (
partial(
_chunk_slice,
flat_start=i,
flat_end=min(flat_batch_dim, i + chunk_size),
no_batch_dims=len(orig_batch_dims)
)
)
chunks = tensor_tree_map(select_chunk, prepped_inputs)
# Run the layer on the chunk
output_chunk = layer(**chunks)
# Allocate space for the output
if out is None:
allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:])
out = tensor_tree_map(allocate, output_chunk)
# Put the chunk in its pre-allocated space
out_type = type(output_chunk)
if out_type is dict:
def assign(d1, d2):
for k, v in d1.items():
if type(v) is dict:
assign(v, d2[k])
else:
if(_add_into_out):
v[i: i + chunk_size] += d2[k]
else:
v[i: i + chunk_size] = d2[k]
assign(out, output_chunk)
elif out_type is tuple:
for x1, x2 in zip(out, output_chunk):
if(_add_into_out):
x1[i: i + chunk_size] += x2
else:
x1[i : i + chunk_size] = x2
elif out_type is torch.Tensor:
if(_add_into_out):
out[i: i + chunk_size] += output_chunk
else:
out[i: i + chunk_size] = output_chunk
else:
raise ValueError("Not supported")
i += chunk_size
reshape = lambda t: t.view(orig_batch_dims + t.shape[1:])
out = tensor_tree_map(reshape, out)
return out
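A minimal usage sketch (the layer and shapes are illustrative): a keyword-called layer is applied over the flattened (2, 6) batch in chunks of 4, and the output is reassembled to the original batch shape:

import torch

def layer(x):
    return {"out": x * 2}

inputs = {"x": torch.randn(2, 6, 8)}  # two batch dimensions: (2, 6)
out = chunk_layer(layer, inputs, chunk_size=4, no_batch_dims=2)

assert out["out"].shape == (2, 6, 8)
assert torch.allclose(out["out"], inputs["x"] * 2)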
class ChunkSizeTuner:
def __init__(self,
# Heuristically, runtimes for most of the modules in the network
# plateau earlier than this on all GPUs I've run the model on.
max_chunk_size=512,
):
self.max_chunk_size = max_chunk_size
self.cached_chunk_size = None
self.cached_arg_data = None
def _determine_favorable_chunk_size(self, fn, args, min_chunk_size):
logging.info("Tuning chunk size...")
if(min_chunk_size >= self.max_chunk_size):
return min_chunk_size
candidates = [2**l for l in range(int(math.log(self.max_chunk_size, 2)) + 1)]
candidates = [c for c in candidates if c > min_chunk_size]
candidates = [min_chunk_size] + candidates
candidates[-1] += 4
def test_chunk_size(chunk_size):
try:
with torch.no_grad():
fn(*args, chunk_size=chunk_size)
return True
except RuntimeError:
return False
min_viable_chunk_size_index = 0
i = len(candidates) - 1
while i > min_viable_chunk_size_index:
viable = test_chunk_size(candidates[i])
if(not viable):
i = (min_viable_chunk_size_index + i) // 2
else:
min_viable_chunk_size_index = i
i = (i + len(candidates) - 1) // 2
return candidates[min_viable_chunk_size_index]
def _compare_arg_caches(self, ac1, ac2):
consistent = True
for a1, a2 in zip(ac1, ac2):
            assert(type(a1) == type(a2))
            if(type(a1) is list or type(a1) is tuple):
                consistent &= self._compare_arg_caches(a1, a2)
            elif(type(a1) is dict):
a1_items = [
v for _, v in sorted(a1.items(), key=lambda x: x[0])
]
a2_items = [
v for _, v in sorted(a2.items(), key=lambda x: x[0])
]
consistent &= self._compare_arg_caches(a1_items, a2_items)
else:
consistent &= a1 == a2
return consistent
def tune_chunk_size(self,
representative_fn: Callable,
args: Tuple[Any],
min_chunk_size: int,
) -> int:
consistent = True
remove_tensors = lambda a: a.shape if type(a) is torch.Tensor else a
arg_data = tree_map(remove_tensors, args, object)
if(self.cached_arg_data is not None):
# If args have changed shape/value, we need to re-tune
assert(len(self.cached_arg_data) == len(arg_data))
consistent = self._compare_arg_caches(
self.cached_arg_data, arg_data
)
else:
            # No cached arg data yet, so we need to tune from scratch
consistent = False
if(not consistent):
self.cached_chunk_size = self._determine_favorable_chunk_size(
representative_fn,
args,
min_chunk_size,
)
self.cached_arg_data = arg_data
return self.cached_chunk_size
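A sketch of the tuning flow (the representative function is a stand-in; in real use, a candidate chunk size that exhausts memory raises RuntimeError and steers the binary search downward). The candidate list is min_chunk_size followed by the larger powers of two up to max_chunk_size, with the last entry nudged up by 4:

import torch

tuner = ChunkSizeTuner(max_chunk_size=512)

def representative_fn(x, chunk_size=None):
    # Stand-in for a chunked module call; an OOM here would surface
    # as a RuntimeError and mark this chunk size as non-viable.
    return x * 2

x = torch.randn(4, 8)
size = tuner.tune_chunk_size(representative_fn, (x,), min_chunk_size=4)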
......@@ -58,7 +58,8 @@ class ExponentialMovingAverage:
self._update_state_dict_(model.state_dict(), self.params)
def load_state_dict(self, state_dict: OrderedDict) -> None:
self.params = state_dict["params"]
for k in state_dict["params"].keys():
self.params[k] = state_dict["params"][k].clone()
self.decay = state_dict["decay"]
def state_dict(self) -> OrderedDict:
......
......@@ -22,7 +22,7 @@ from typing import Dict, Union
from openfold.np import protein
import openfold.np.residue_constants as rc
from openfold.utils.geometry import rigid_matrix_vector, rotation_matrix
from openfold.utils.geometry import rigid_matrix_vector, rotation_matrix, vector
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import (
batched_gather,
......@@ -188,13 +188,16 @@ def torsion_angles_to_frames(
aatype: torch.Tensor,
rrgdf: torch.Tensor,
):
rigid_type = Rigid if isinstance(r, Rigid) else rigid_matrix_vector.Rigid3Array
# [*, N, 8, 4, 4]
default_4x4 = rrgdf[aatype, ...]
# [*, N, 8] transformations, i.e.
# One [*, N, 8, 3, 3] rotation matrix and
    # One [*, N, 8, 3] stack of translation vectors
default_r = r.from_tensor_4x4(default_4x4)
default_r = rigid_type.from_tensor_4x4(default_4x4)
bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
bb_rot[..., 1] = 1
......@@ -221,11 +224,9 @@ def torsion_angles_to_frames(
all_rots[..., 2, 1:] = alpha
if isinstance(r, Rigid):
rigid_type = Rigid
all_rots = Rigid(Rotation(rot_mats=all_rots), None)
all_frames = default_r.compose(all_rots)
else:
rigid_type = rigid_matrix_vector.Rigid3Array
all_rots = rotation_matrix.Rot3Array.from_array(all_rots)
all_frames = default_r.compose_rotation(all_rots)
......@@ -291,4 +292,7 @@ def frames_and_literature_positions_to_atom14_pos(
pred_positions = t_atoms_to_global.apply(lit_positions)
pred_positions = pred_positions * atom_mask
if isinstance(pred_positions, vector.Vec3Array):
return pred_positions.to_tensor()
return pred_positions
......@@ -16,7 +16,7 @@ class QuatRigid(nn.Module):
else:
rigid_dim = 6
self.linear = Linear(c_hidden, rigid_dim)
self.linear = Linear(c_hidden, rigid_dim, init="final")
def forward(self, activations: torch.Tensor) -> Rigid3Array:
# NOTE: During training, this needs to be run in higher precision
......
......@@ -59,14 +59,14 @@ class Param:
stacked: bool = False
def _process_translations_dict(d, top_layer=True):
def process_translation_dict(d, top_layer=True):
flat = {}
for k, v in d.items():
if type(v) == dict:
prefix = _NPZ_KEY_PREFIX if top_layer else ""
sub_flat = {
(prefix + "/".join([k, k_prime])): v_prime
for k_prime, v_prime in _process_translations_dict(
for k_prime, v_prime in process_translation_dict(
v, top_layer=False
).items()
}
......@@ -129,7 +129,7 @@ def assign(translation_dict, orig_weights):
raise
def get_translation_dict(model, version, is_multimer=False):
def generate_translation_dict(model, version, is_multimer=False):
#######################
# Some templates
#######################
......@@ -277,7 +277,7 @@ def get_translation_dict(model, version, is_multimer=False):
},
"v_scalar_projection": {
"weights": LinearWeightMultimer(
ipa.linear_k.weight,
ipa.linear_v.weight,
),
},
"q_point_projection": PointProjectionParams(
......@@ -388,11 +388,6 @@ def get_translation_dict(model, version, is_multimer=False):
############################
# translations dict overflow
############################
tps_blocks = model.template_embedder.template_pair_stack.blocks
tps_blocks_params = stacked(
[TemplatePairBlockParams(b) for b in tps_blocks]
)
ems_blocks = model.extra_msa_stack.blocks
ems_blocks_params = stacked([ExtraMSABlockParams(b) for b in ems_blocks])
......@@ -416,32 +411,10 @@ def get_translation_dict(model, version, is_multimer=False):
"pair_activiations": LinearParams(
model.input_embedder.linear_relpos
),
"template_embedding": {
"single_template_embedding": {
"embedding2d": LinearParams(
model.template_embedder.template_pair_embedder.linear
),
"template_pair_stack": {
"__layer_stack_no_state": tps_blocks_params,
},
"output_layer_norm": LayerNormParams(
model.template_embedder.template_pair_stack.layer_norm
),
},
"attention": AttentionParams(
model.template_embedder.template_pointwise_att.mha
),
},
"extra_msa_activations": LinearParams(
model.extra_msa_embedder.linear
),
"extra_msa_stack": ems_blocks_params,
"template_single_embedding": LinearParams(
model.template_embedder.template_angle_embedder.linear_1
),
"template_projection": LinearParams(
model.template_embedder.template_angle_embedder.linear_2
),
"evoformer_iteration": evo_blocks_params,
"single_activations": LinearParams(model.evoformer.linear),
},
......@@ -478,7 +451,6 @@ def get_translation_dict(model, version, is_multimer=False):
},
}
else:
temp_embedder = model.template_embedder
translations = {
"evoformer": {
"preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
......@@ -497,53 +469,6 @@ def get_translation_dict(model, version, is_multimer=False):
model.input_embedder.linear_relpos
),
},
"template_embedding": {
"single_template_embedding": {
"query_embedding_norm": LayerNormParams(
temp_embedder.template_pair_embedder.query_embedding_layer_norm
),
"template_pair_embedding_0": LinearParams(
temp_embedder.template_pair_embedder.dgram_linear
),
"template_pair_embedding_1": LinearParamsMultimer(
temp_embedder.template_pair_embedder.pseudo_beta_mask_linear
),
"template_pair_embedding_2": LinearParams(
temp_embedder.template_pair_embedder.aatype_linear_1
),
"template_pair_embedding_3": LinearParams(
temp_embedder.template_pair_embedder.aatype_linear_2
),
"template_pair_embedding_4": LinearParamsMultimer(
temp_embedder.template_pair_embedder.x_linear
),
"template_pair_embedding_5": LinearParamsMultimer(
temp_embedder.template_pair_embedder.y_linear
),
"template_pair_embedding_6": LinearParamsMultimer(
temp_embedder.template_pair_embedder.z_linear
),
"template_pair_embedding_7": LinearParamsMultimer(
temp_embedder.template_pair_embedder.backbone_mask_linear
),
"template_pair_embedding_8": LinearParams(
temp_embedder.template_pair_embedder.query_embedding_linear
),
"template_embedding_iteration": tps_blocks_params,
"output_layer_norm": LayerNormParams(
model.template_embedder.template_pair_stack.layer_norm
),
},
"output_linear": LinearParams(
temp_embedder.linear_t
),
},
"template_projection": LinearParams(
temp_embedder.template_single_embedder.template_projector,
),
"template_single_embedding": LinearParams(
temp_embedder.template_single_embedder.template_single_embedder,
),
"extra_msa_activations": LinearParams(
model.extra_msa_embedder.linear
),
......@@ -592,12 +517,88 @@ def get_translation_dict(model, version, is_multimer=False):
"model_4_ptm",
"model_5_ptm",
]
if version in no_templ:
evo_dict = translations["evoformer"]
keys = list(evo_dict.keys())
for k in keys:
if "template_" in k:
evo_dict.pop(k)
if version not in no_templ:
tps_blocks = model.template_embedder.template_pair_stack.blocks
tps_blocks_params = stacked(
[TemplatePairBlockParams(b) for b in tps_blocks]
)
if (not is_multimer):
template_param_dict = {
"template_embedding": {
"single_template_embedding": {
"embedding2d": LinearParams(
model.template_embedder.template_pair_embedder.linear
),
"template_pair_stack": {
"__layer_stack_no_state": tps_blocks_params,
},
"output_layer_norm": LayerNormParams(
model.template_embedder.template_pair_stack.layer_norm
),
},
"attention": AttentionParams(model.template_embedder.template_pointwise_att.mha),
},
"template_single_embedding": LinearParams(
model.template_embedder.template_angle_embedder.linear_1
),
"template_projection": LinearParams(
model.template_embedder.template_angle_embedder.linear_2
),
}
else:
temp_embedder = model.template_embedder
template_param_dict = {
"template_embedding": {
"single_template_embedding": {
"query_embedding_norm": LayerNormParams(
temp_embedder.template_pair_embedder.query_embedding_layer_norm
),
"template_pair_embedding_0": LinearParams(
temp_embedder.template_pair_embedder.dgram_linear
),
"template_pair_embedding_1": LinearParamsMultimer(
temp_embedder.template_pair_embedder.pseudo_beta_mask_linear
),
"template_pair_embedding_2": LinearParams(
temp_embedder.template_pair_embedder.aatype_linear_1
),
"template_pair_embedding_3": LinearParams(
temp_embedder.template_pair_embedder.aatype_linear_2
),
"template_pair_embedding_4": LinearParamsMultimer(
temp_embedder.template_pair_embedder.x_linear
),
"template_pair_embedding_5": LinearParamsMultimer(
temp_embedder.template_pair_embedder.y_linear
),
"template_pair_embedding_6": LinearParamsMultimer(
temp_embedder.template_pair_embedder.z_linear
),
"template_pair_embedding_7": LinearParamsMultimer(
temp_embedder.template_pair_embedder.backbone_mask_linear
),
"template_pair_embedding_8": LinearParams(
temp_embedder.template_pair_embedder.query_embedding_linear
),
"template_embedding_iteration": tps_blocks_params,
"output_layer_norm": LayerNormParams(
model.template_embedder.template_pair_stack.layer_norm
),
},
"output_linear": LinearParams(
temp_embedder.linear_t
),
},
"template_projection": LinearParams(
temp_embedder.template_single_embedder.template_projector,
),
"template_single_embedding": LinearParams(
temp_embedder.template_single_embedder.template_single_embedder,
),
}
translations["evoformer"].update(template_param_dict)
if "_ptm" in version:
translations["predicted_aligned_error_head"] = {
......@@ -609,15 +610,10 @@ def get_translation_dict(model, version, is_multimer=False):
def import_jax_weights_(model, npz_path, version="model_1"):
data = np.load(npz_path)
translations = get_translation_dict(
model,
version,
is_multimer=("multimer" in version)
)
translations = generate_translation_dict(model, version, is_multimer=("multimer" in version))
# Flatten keys and insert missing key prefixes
flat = _process_translations_dict(translations)
flat = process_translation_dict(translations)
# Sanity check
keys = list(data.keys())
......
// Copyright 2021 AlQuraishi Laboratory
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// modified from fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda.cpp
#include <torch/extension.h>
void attn_softmax_inplace_forward_(
at::Tensor input,
long long rows, int cols
)
{
throw std::runtime_error("attn_softmax_inplace_forward_ not implemented on CPU");
};
void attn_softmax_inplace_backward_(
at::Tensor output,
at::Tensor d_ov,
at::Tensor values,
long long rows,
int cols_output,
int cols_values
)
{
throw std::runtime_error("attn_softmax_inplace_backward_ not implemented on CPU");
};
\ No newline at end of file
......@@ -11,16 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import operator
import time
import dllogger as logger
import numpy as np
import torch.cuda.profiler as profiler
from dllogger import JSONStreamBackend, StdOutBackend, Verbosity
import numpy as np
from pytorch_lightning import Callback
import torch.cuda.profiler as profiler
def is_main_process():
......
......@@ -43,9 +43,15 @@ def softmax_cross_entropy(logits, labels):
def sigmoid_cross_entropy(logits, labels):
log_p = torch.log(torch.sigmoid(logits))
log_not_p = torch.log(torch.sigmoid(-logits))
loss = -labels * log_p - (1 - labels) * log_not_p
logits_dtype = logits.dtype
logits = logits.double()
labels = labels.double()
log_p = torch.nn.functional.logsigmoid(logits)
# log_p = torch.log(torch.sigmoid(logits))
log_not_p = torch.nn.functional.logsigmoid(-1 * logits)
# log_not_p = torch.log(torch.sigmoid(-logits))
loss = (-1. * labels) * log_p - (1. - labels) * log_not_p
loss = loss.to(dtype=logits_dtype)
return loss
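The switch to logsigmoid, together with the temporary upcast to double, is a numerical-stability fix: for large-magnitude logits the naive formulation saturates to -inf, e.g.:

import torch

x = torch.tensor([100.0])
print(torch.log(torch.sigmoid(-x)))        # tensor([-inf])
print(torch.nn.functional.logsigmoid(-x))  # tensor([-100.])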
......@@ -658,10 +664,11 @@ def compute_tm(
denom = eps + torch.sum(pair_residue_weights, dim=-1, keepdims=True)
normed_residue_mask = pair_residue_weights / denom
per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)
weighted = per_alignment * residue_weights
idx = weighted.argmax(dim=-1, keepdim=True)
return torch.gather(per_alignment, -1, idx).squeeze(-1)
argmax = (weighted == torch.max(weighted)).nonzero()[0]
return per_alignment[tuple(argmax)]
def tm_loss(
logits,
......@@ -1483,17 +1490,17 @@ def experimentally_resolved_loss(
loss = torch.sum(errors * atom37_atom_exists, dim=-1)
loss = loss / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2)))
loss = torch.sum(loss, dim=-1)
loss = loss * (
(resolution >= min_resolution) & (resolution <= max_resolution)
)
loss = torch.mean(loss)
return loss
def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
def masked_msa_loss(logits, true_msa, bert_mask, num_classes, eps=1e-8, **kwargs):
"""
Computes BERT-style masked MSA loss. Implements subsection 1.9.9.
......@@ -1505,7 +1512,7 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
Masked MSA loss
"""
errors = softmax_cross_entropy(
logits, torch.nn.functional.one_hot(true_msa, num_classes=23)
logits, torch.nn.functional.one_hot(true_msa, num_classes=num_classes)
)
# FP16-friendly averaging. Equivalent to:
......@@ -1562,10 +1569,10 @@ class AlphaFoldLoss(nn.Module):
batch,
self.config.fape,
),
"lddt": lambda: lddt_loss(
"plddt_loss": lambda: lddt_loss(
logits=out["lddt_logits"],
all_atom_pred_pos=out["final_atom_positions"],
**{**batch, **self.config.lddt},
**{**batch, **self.config.plddt_loss},
),
"masked_msa": lambda: masked_msa_loss(
logits=out["masked_msa_logits"],
......
# Copyright 2022 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import torch
def is_fp16_enabled():
# Autocast world
fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16
fp16_enabled = fp16_enabled and torch.is_autocast_enabled()
return fp16_enabled
......@@ -14,6 +14,7 @@
# limitations under the License.
from __future__ import annotations
from functools import lru_cache
from typing import Tuple, Any, Sequence, Callable, Optional
import numpy as np
......@@ -34,51 +35,31 @@ def rot_matmul(
Returns:
The product ab
"""
row_1 = torch.stack(
[
a[..., 0, 0] * b[..., 0, 0]
+ a[..., 0, 1] * b[..., 1, 0]
+ a[..., 0, 2] * b[..., 2, 0],
a[..., 0, 0] * b[..., 0, 1]
+ a[..., 0, 1] * b[..., 1, 1]
+ a[..., 0, 2] * b[..., 2, 1],
a[..., 0, 0] * b[..., 0, 2]
+ a[..., 0, 1] * b[..., 1, 2]
+ a[..., 0, 2] * b[..., 2, 2],
],
dim=-1,
)
row_2 = torch.stack(
[
a[..., 1, 0] * b[..., 0, 0]
+ a[..., 1, 1] * b[..., 1, 0]
+ a[..., 1, 2] * b[..., 2, 0],
a[..., 1, 0] * b[..., 0, 1]
+ a[..., 1, 1] * b[..., 1, 1]
+ a[..., 1, 2] * b[..., 2, 1],
a[..., 1, 0] * b[..., 0, 2]
+ a[..., 1, 1] * b[..., 1, 2]
+ a[..., 1, 2] * b[..., 2, 2],
],
dim=-1,
)
row_3 = torch.stack(
def row_mul(i):
return torch.stack(
[
a[..., i, 0] * b[..., 0, 0]
+ a[..., i, 1] * b[..., 1, 0]
+ a[..., i, 2] * b[..., 2, 0],
a[..., i, 0] * b[..., 0, 1]
+ a[..., i, 1] * b[..., 1, 1]
+ a[..., i, 2] * b[..., 2, 1],
a[..., i, 0] * b[..., 0, 2]
+ a[..., i, 1] * b[..., 1, 2]
+ a[..., i, 2] * b[..., 2, 2],
],
dim=-1,
)
return torch.stack(
[
a[..., 2, 0] * b[..., 0, 0]
+ a[..., 2, 1] * b[..., 1, 0]
+ a[..., 2, 2] * b[..., 2, 0],
a[..., 2, 0] * b[..., 0, 1]
+ a[..., 2, 1] * b[..., 1, 1]
+ a[..., 2, 2] * b[..., 2, 1],
a[..., 2, 0] * b[..., 0, 2]
+ a[..., 2, 1] * b[..., 1, 2]
+ a[..., 2, 2] * b[..., 2, 2],
],
dim=-1,
row_mul(0),
row_mul(1),
row_mul(2),
],
dim=-2
)
return torch.stack([row_1, row_2, row_3], dim=-2)
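The refactor folds three hand-unrolled row computations into a single row_mul helper without changing the math; a quick sanity check against torch.matmul (random inputs, purely illustrative):

import torch

a, b = torch.randn(5, 3, 3), torch.randn(5, 3, 3)
assert torch.allclose(rot_matmul(a, b), a @ b, atol=1e-5)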
def rot_vec_mul(
r: torch.Tensor,
......@@ -94,9 +75,7 @@ def rot_vec_mul(
Returns:
[*, 3] rotated coordinates
"""
x = t[..., 0]
y = t[..., 1]
z = t[..., 2]
x, y, z = torch.unbind(t, dim=-1)
return torch.stack(
[
r[..., 0, 0] * x + r[..., 0, 1] * y + r[..., 0, 2] * z,
......@@ -106,7 +85,7 @@ def rot_vec_mul(
dim=-1,
)
@lru_cache(maxsize=None)
def identity_rot_mats(
batch_dims: Tuple[int],
dtype: Optional[torch.dtype] = None,
......@@ -118,10 +97,12 @@ def identity_rot_mats(
)
rots = rots.view(*((1,) * len(batch_dims)), 3, 3)
rots = rots.expand(*batch_dims, -1, -1)
rots = rots.contiguous()
return rots
@lru_cache(maxsize=None)
def identity_trans(
batch_dims: Tuple[int],
dtype: Optional[torch.dtype] = None,
......@@ -137,6 +118,7 @@ def identity_trans(
return trans
@lru_cache(maxsize=None)
def identity_quats(
batch_dims: Tuple[int],
dtype: Optional[torch.dtype] = None,
......@@ -196,7 +178,7 @@ def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
quat = quat[..., None] * quat[..., None, :]
# [4, 4, 3, 3]
mat = quat.new_tensor(_QTR_MAT, requires_grad=False)
mat = _get_quat("_QTR_MAT", dtype=quat.dtype, device=quat.device)
# [*, 4, 4, 3, 3]
shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape)
......@@ -251,10 +233,20 @@ _QUAT_MULTIPLY[:, :, 3] = [[ 0, 0, 0, 1],
_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]
_CACHED_QUATS = {
"_QTR_MAT": _QTR_MAT,
"_QUAT_MULTIPLY": _QUAT_MULTIPLY,
"_QUAT_MULTIPLY_BY_VEC": _QUAT_MULTIPLY_BY_VEC
}
@lru_cache(maxsize=None)
def _get_quat(quat_key, dtype, device):
return torch.tensor(_CACHED_QUATS[quat_key], dtype=dtype, device=device)
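Because lru_cache keys on (quat_key, dtype, device), repeated calls return the same constant tensor instead of re-materializing it on every quaternion operation:

import torch

m1 = _get_quat("_QTR_MAT", dtype=torch.float32, device=torch.device("cpu"))
m2 = _get_quat("_QTR_MAT", dtype=torch.float32, device=torch.device("cpu"))
assert m1 is m2  # served from the cache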
def quat_multiply(quat1, quat2):
"""Multiply a quaternion by another quaternion."""
mat = quat1.new_tensor(_QUAT_MULTIPLY)
mat = _get_quat("_QUAT_MULTIPLY", dtype=quat1.dtype, device=quat1.device)
reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape)
return torch.sum(
reshaped_mat *
......@@ -266,7 +258,7 @@ def quat_multiply(quat1, quat2):
def quat_multiply_by_vec(quat, vec):
"""Multiply a quaternion by a pure-vector quaternion."""
mat = quat.new_tensor(_QUAT_MULTIPLY_BY_VEC)
mat = _get_quat("_QUAT_MULTIPLY_BY_VEC", dtype=quat.dtype, device=quat.device)
reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape)
return torch.sum(
reshaped_mat *
......
import json
import logging
import os
import re
import time
import numpy
import torch
from openfold.model.model import AlphaFold
from openfold.np import residue_constants, protein
from openfold.np.relax import relax
from openfold.utils.import_weights import (
import_jax_weights_,
)
from pytorch_lightning.utilities.deepspeed import (
convert_zero_checkpoint_to_fp32_state_dict
)
logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)
def count_models_to_evaluate(openfold_checkpoint_path, jax_param_path):
model_count = 0
if openfold_checkpoint_path:
model_count += len(openfold_checkpoint_path.split(","))
if jax_param_path:
model_count += len(jax_param_path.split(","))
return model_count
def get_model_basename(model_path):
return os.path.splitext(
os.path.basename(
os.path.normpath(model_path)
)
)[0]
def make_output_directory(output_dir, model_name, multiple_model_mode):
if multiple_model_mode:
prediction_dir = os.path.join(output_dir, "predictions", model_name)
else:
prediction_dir = os.path.join(output_dir, "predictions")
os.makedirs(prediction_dir, exist_ok=True)
return prediction_dir
def load_models_from_command_line(config, model_device, openfold_checkpoint_path, jax_param_path, output_dir):
# Create the output directory
multiple_model_mode = count_models_to_evaluate(openfold_checkpoint_path, jax_param_path) > 1
if multiple_model_mode:
logger.info(f"evaluating multiple models")
if jax_param_path:
for path in jax_param_path.split(","):
model_basename = get_model_basename(path)
model_version = "_".join(model_basename.split("_")[1:])
model = AlphaFold(config)
model = model.eval()
import_jax_weights_(
model, path, version=model_version
)
model = model.to(model_device)
logger.info(
f"Successfully loaded JAX parameters at {path}..."
)
output_directory = make_output_directory(output_dir, model_basename, multiple_model_mode)
yield model, output_directory
if openfold_checkpoint_path:
for path in openfold_checkpoint_path.split(","):
model = AlphaFold(config)
model = model.eval()
checkpoint_basename = get_model_basename(path)
if os.path.isdir(path):
# A DeepSpeed checkpoint
ckpt_path = os.path.join(
output_dir,
checkpoint_basename + ".pt",
)
if not os.path.isfile(ckpt_path):
convert_zero_checkpoint_to_fp32_state_dict(
path,
ckpt_path,
)
d = torch.load(ckpt_path)
model.load_state_dict(d["ema"]["params"])
else:
ckpt_path = path
d = torch.load(ckpt_path)
if "ema" in d:
# The public weights have had this done to them already
d = d["ema"]["params"]
model.load_state_dict(d)
model = model.to(model_device)
logger.info(
f"Loaded OpenFold parameters at {path}..."
)
output_directory = make_output_directory(output_dir, checkpoint_basename, multiple_model_mode)
yield model, output_directory
if not jax_param_path and not openfold_checkpoint_path:
raise ValueError(
"At least one of jax_param_path or openfold_checkpoint_path must "
"be specified."
)
def parse_fasta(data):
data = re.sub('>$', '', data, flags=re.M)
lines = [
l.replace('\n', '')
for prot in data.split('>') for l in prot.strip().split('\n', 1)
][1:]
tags, seqs = lines[::2], lines[1::2]
tags = [t.split()[0] for t in tags]
return tags, seqs
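A quick illustration on a hypothetical two-record FASTA string:

data = ">seqA some description\nMKV\n>seqB\nGHL\n"
tags, seqs = parse_fasta(data)
assert tags == ["seqA", "seqB"]
assert seqs == ["MKV", "GHL"]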
def update_timings(timing_dict, output_file=os.path.join(os.getcwd(), "timings.json")):
"""
Write dictionary of one or more run step times to a file
"""
if os.path.exists(output_file):
with open(output_file, "r") as f:
try:
timings = json.load(f)
except json.JSONDecodeError:
logger.info(f"Overwriting non-standard JSON in {output_file}.")
timings = {}
else:
timings = {}
timings.update(timing_dict)
with open(output_file, "w") as f:
json.dump(timings, f)
return output_file
def run_model(model, batch, tag, output_dir):
with torch.no_grad():
# Temporarily disable templates if there aren't any in the batch
template_enabled = model.config.template.enabled
model.config.template.enabled = template_enabled and any([
"template_" in k for k in batch
])
logger.info(f"Running inference for {tag}...")
t = time.perf_counter()
out = model(batch)
inference_time = time.perf_counter() - t
logger.info(f"Inference time: {inference_time}")
update_timings({"inference": inference_time}, os.path.join(output_dir, "timings.json"))
model.config.template.enabled = template_enabled
return out
def prep_output(out, batch, feature_dict, feature_processor, config_preset, multimer_ri_gap, subtract_plddt):
plddt = out["plddt"]
plddt_b_factors = numpy.repeat(
plddt[..., None], residue_constants.atom_type_num, axis=-1
)
if subtract_plddt:
plddt_b_factors = 100 - plddt_b_factors
# Prep protein metadata
template_domain_names = []
template_chain_index = None
if feature_processor.config.common.use_templates and "template_domain_names" in feature_dict:
template_domain_names = [
t.decode("utf-8") for t in feature_dict["template_domain_names"]
]
# This works because templates are not shuffled during inference
template_domain_names = template_domain_names[
:feature_processor.config.predict.max_templates
]
if "template_chain_index" in feature_dict:
template_chain_index = feature_dict["template_chain_index"]
template_chain_index = template_chain_index[
:feature_processor.config.predict.max_templates
]
no_recycling = feature_processor.config.common.max_recycling_iters
remark = ', '.join([
f"no_recycling={no_recycling}",
f"max_templates={feature_processor.config.predict.max_templates}",
f"config_preset={config_preset}",
])
# For multi-chain FASTAs
ri = feature_dict["residue_index"]
chain_index = (ri - numpy.arange(ri.shape[0])) / multimer_ri_gap
chain_index = chain_index.astype(numpy.int64)
cur_chain = 0
prev_chain_max = 0
for i, c in enumerate(chain_index):
if c != cur_chain:
cur_chain = c
prev_chain_max = i + cur_chain * multimer_ri_gap
batch["residue_index"][i] -= prev_chain_max
unrelaxed_protein = protein.from_prediction(
features=batch,
result=out,
b_factors=plddt_b_factors,
        remove_leading_feature_dimension=("multimer" not in config_preset),
remark=remark,
parents=template_domain_names,
parents_chain_index=template_chain_index,
)
return unrelaxed_protein
def relax_protein(config, model_device, unrelaxed_protein, output_directory, output_name):
amber_relaxer = relax.AmberRelaxation(
use_gpu=(model_device != "cpu"),
**config.relax,
)
t = time.perf_counter()
visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default="")
if "cuda" in model_device:
device_no = model_device.split(":")[-1]
os.environ["CUDA_VISIBLE_DEVICES"] = device_no
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
relaxation_time = time.perf_counter() - t
logger.info(f"Relaxation time: {relaxation_time}")
update_timings({"relaxation": relaxation_time}, os.path.join(output_directory, "timings.json"))
# Save the relaxed PDB.
relaxed_output_path = os.path.join(
output_directory, f'{output_name}_relaxed.pdb'
)
with open(relaxed_output_path, 'w') as fp:
fp.write(relaxed_pdb_str)
logger.info(f"Relaxed output written to {relaxed_output_path}...")
\ No newline at end of file
......@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from Bio.SVDSuperimposer import SVDSuperimposer
import numpy as np
import torch
......
......@@ -14,9 +14,22 @@
# limitations under the License.
from functools import partial
import logging
from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional
import torch
import torch.nn as nn
from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional
def add(m1, m2, inplace):
# The first operation in a checkpoint can't be in-place, but it's
# nice to have in-place addition during inference. Thus...
if(not inplace):
m1 = m1 + m2
else:
m1 += m2
return m1
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
......@@ -106,303 +119,3 @@ def tree_map(fn, tree, leaf_type):
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
def _fetch_dims(tree):
shapes = []
tree_type = type(tree)
if tree_type is dict:
for v in tree.values():
shapes.extend(_fetch_dims(v))
elif tree_type is list or tree_type is tuple:
for t in tree:
shapes.extend(_fetch_dims(t))
elif tree_type is torch.Tensor:
shapes.append(tree.shape)
else:
raise ValueError("Not supported")
return shapes
@torch.jit.ignore
def _flat_idx_to_idx(
flat_idx: int,
dims: Tuple[int],
) -> Tuple[int]:
idx = []
for d in reversed(dims):
idx.append(flat_idx % d)
flat_idx = flat_idx // d
return tuple(reversed(idx))
@torch.jit.ignore
def _get_minimal_slice_set(
start: Sequence[int],
end: Sequence[int],
dims: int,
start_edges: Optional[Sequence[bool]] = None,
end_edges: Optional[Sequence[bool]] = None,
) -> Sequence[Tuple[int]]:
"""
Produces an ordered sequence of tensor slices that, when used in
sequence on a tensor with shape dims, yields tensors that contain every
leaf in the contiguous range [start, end]. Care is taken to yield a
short sequence of slices, and perhaps even the shortest possible (I'm
pretty sure it's the latter).
end is INCLUSIVE.
"""
# start_edges and end_edges both indicate whether, starting from any given
# dimension, the start/end index is at the top/bottom edge of the
# corresponding tensor, modeled as a tree
def reduce_edge_list(l):
tally = 1
for i in range(len(l)):
reversed_idx = -1 * (i + 1)
l[reversed_idx] *= tally
tally = l[reversed_idx]
if(start_edges is None):
start_edges = [s == 0 for s in start]
reduce_edge_list(start_edges)
if(end_edges is None):
end_edges = [e == (d - 1) for e,d in zip(end, dims)]
reduce_edge_list(end_edges)
# Base cases. Either start/end are empty and we're done, or the final,
# one-dimensional tensor can be simply sliced
if(len(start) == 0):
return [tuple()]
elif(len(start) == 1):
return [(slice(start[0], end[0] + 1),)]
slices = []
path = []
# Dimensions common to start and end can be selected directly
for s,e in zip(start, end):
if(s == e):
path.append(slice(s, s + 1))
else:
break
path = tuple(path)
divergence_idx = len(path)
# start == end, and we're done
if(divergence_idx == len(dims)):
return [tuple(path)]
def upper():
sdi = start[divergence_idx]
return [
path + (slice(sdi, sdi + 1),) + s for s in
_get_minimal_slice_set(
start[divergence_idx + 1:],
[d - 1 for d in dims[divergence_idx + 1:]],
dims[divergence_idx + 1:],
start_edges=start_edges[divergence_idx + 1:],
end_edges=[1 for _ in end_edges[divergence_idx + 1:]]
)
]
def lower():
edi = end[divergence_idx]
return [
path + (slice(edi, edi + 1),) + s for s in
_get_minimal_slice_set(
[0 for _ in start[divergence_idx + 1:]],
end[divergence_idx + 1:],
dims[divergence_idx + 1:],
start_edges=[1 for _ in start_edges[divergence_idx + 1:]],
end_edges=end_edges[divergence_idx + 1:],
)
]
# If both start and end are at the edges of the subtree rooted at
# divergence_idx, we can just select the whole subtree at once
if(start_edges[divergence_idx] and end_edges[divergence_idx]):
slices.append(
path + (slice(start[divergence_idx], end[divergence_idx] + 1),)
)
# If just start is at the edge, we can grab almost all of the subtree,
# treating only the ragged bottom edge as an edge case
elif(start_edges[divergence_idx]):
slices.append(
path + (slice(start[divergence_idx], end[divergence_idx]),)
)
slices.extend(lower())
# Analogous to the previous case, but the top is ragged this time
elif(end_edges[divergence_idx]):
slices.extend(upper())
slices.append(
path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),)
)
# If both sides of the range are ragged, we need to handle both sides
# separately. If there's contiguous meat in between them, we can index it
# in one big chunk
else:
slices.extend(upper())
middle_ground = end[divergence_idx] - start[divergence_idx]
if(middle_ground > 1):
slices.append(
path + (slice(start[divergence_idx] + 1, end[divergence_idx]),)
)
slices.extend(lower())
return [tuple(s) for s in slices]
@torch.jit.ignore
def _chunk_slice(
t: torch.Tensor,
flat_start: int,
flat_end: int,
no_batch_dims: int,
) -> torch.Tensor:
"""
Equivalent to
t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]
but without the need for the initial reshape call, which can be
memory-intensive in certain situations. The only reshape operations
in this function are performed on sub-tensors that scale with
(flat_end - flat_start), the chunk size.
"""
batch_dims = t.shape[:no_batch_dims]
start_idx = list(_flat_idx_to_idx(flat_start, batch_dims))
# _get_minimal_slice_set is inclusive
end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims))
# Get an ordered list of slices to perform
slices = _get_minimal_slice_set(
start_idx,
end_idx,
batch_dims,
)
sliced_tensors = [t[s] for s in slices]
return torch.cat(
[s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors]
)
def chunk_layer(
layer: Callable,
inputs: Dict[str, Any],
chunk_size: int,
no_batch_dims: int,
low_mem: bool = False,
) -> Any:
"""
Implements the "chunking" procedure described in section 1.11.8.
Layer outputs and inputs are assumed to be simple "pytrees,"
consisting only of (arbitrarily nested) lists, tuples, and dicts with
torch.Tensor leaves.
Args:
layer:
The layer to be applied chunk-wise
inputs:
A (non-nested) dictionary of keyworded inputs. All leaves must
be tensors and must share the same batch dimensions.
chunk_size:
The number of sub-batches per chunk. If multiple batch
dimensions are specified, a "sub-batch" is defined as a single
indexing of all batch dimensions simultaneously (s.t. the
number of sub-batches is the product of the batch dimensions).
no_batch_dims:
How many of the initial dimensions of each input tensor can
be considered batch dimensions.
low_mem:
Avoids flattening potentially large input tensors. Unnecessary
in most cases, and is ever so slightly slower than the default
setting.
Returns:
The reassembled output of the layer on the inputs.
"""
if not (len(inputs) > 0):
raise ValueError("Must provide at least one input")
initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
def _prep_inputs(t):
# TODO: make this more memory efficient. This sucks
if(not low_mem):
if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
t = t.reshape(-1, *t.shape[no_batch_dims:])
else:
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
return t
prepped_inputs = tensor_tree_map(_prep_inputs, inputs)
flat_batch_dim = 1
for d in orig_batch_dims:
flat_batch_dim *= d
no_chunks = flat_batch_dim // chunk_size + (
flat_batch_dim % chunk_size != 0
)
i = 0
out = None
for _ in range(no_chunks):
# Chunk the input
if(not low_mem):
select_chunk = (
lambda t: t[i : i + chunk_size] if t.shape[0] != 1 else t
)
else:
select_chunk = (
partial(
_chunk_slice,
flat_start=i,
flat_end=min(flat_batch_dim, i + chunk_size),
no_batch_dims=len(orig_batch_dims)
)
)
chunks = tensor_tree_map(select_chunk, prepped_inputs)
# Run the layer on the chunk
output_chunk = layer(**chunks)
# Allocate space for the output
if out is None:
allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:])
out = tensor_tree_map(allocate, output_chunk)
# Put the chunk in its pre-allocated space
out_type = type(output_chunk)
if out_type is dict:
def assign(d1, d2):
for k, v in d1.items():
if type(v) is dict:
assign(v, d2[k])
else:
v[i : i + chunk_size] = d2[k]
assign(out, output_chunk)
elif out_type is tuple:
for x1, x2 in zip(out, output_chunk):
x1[i : i + chunk_size] = x2
elif out_type is torch.Tensor:
out[i : i + chunk_size] = output_chunk
else:
raise ValueError("Not supported")
i += chunk_size
reshape = lambda t: t.view(orig_batch_dims + t.shape[1:])
out = tensor_tree_map(reshape, out)
return out
# Copyright 2022 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from functools import partialmethod
import numpy as np
import torch
from openfold.utils.tensor_utils import tensor_tree_map
def pad_feature_dict_seq(feature_dict, seqlen):
""" Pads the sequence length of a feature dict. Used for tracing. """
# The real sequence length can't be longer than the desired one
true_n = feature_dict["aatype"].shape[-2]
assert(true_n <= seqlen)
new_feature_dict = {}
feat_seq_dims = {
"aatype": -2,
"between_segment_residues": -1,
"residue_index": -1,
"seq_length": -1,
"deletion_matrix_int": -1,
"msa": -1,
"num_alignments": -1,
"template_aatype": -2,
"template_all_atom_mask": -2,
"template_all_atom_positions": -3,
}
for k,v in feature_dict.items():
if(k not in feat_seq_dims):
new_feature_dict[k] = v
continue
seq_dim = feat_seq_dims[k]
padded_shape = list(v.shape)
padded_shape[seq_dim] = seqlen
new_value = np.zeros(padded_shape, dtype=v.dtype)
new_value[tuple(slice(0, s) for s in v.shape)] = v
new_feature_dict[k] = new_value
new_feature_dict["seq_length"][0] = seqlen
return new_feature_dict
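A minimal sketch with a few of the padded keys (shapes illustrative):

import numpy as np

feature_dict = {
    "aatype": np.zeros((7, 21), dtype=np.float32),
    "residue_index": np.arange(7, dtype=np.int64),
    "seq_length": np.full((7,), 7, dtype=np.int64),
}
padded = pad_feature_dict_seq(feature_dict, 16)
assert padded["aatype"].shape == (16, 21)
assert padded["seq_length"][0] == 16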
def trace_model_(model, sample_input):
# Grab the inputs to the final recycling iteration
feats = tensor_tree_map(lambda t: t[..., -1], sample_input)
# Gather some metadata
n = feats["aatype"].shape[-1]
msa_depth = feats["true_msa"].shape[-2]
extra_msa_depth = feats["extra_msa"].shape[-2]
no_templates = feats["template_aatype"].shape[-2]
device = feats["aatype"].device
seq_mask = feats["seq_mask"].to(device)
pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
extra_msa_mask = feats["extra_msa_mask"].to(device)
template_pair_mask = torch.stack([pair_mask] * no_templates, dim=-3)
# Create some fake representations with the correct shapes
m = torch.rand(msa_depth + 4, n, model.globals.c_m).to(device)
z = torch.rand(n, n, model.globals.c_z).to(device)
t = torch.rand(no_templates, n, n, model.globals.c_t).to(device)
a = torch.rand(extra_msa_depth, n, model.globals.c_e).to(device)
msa_mask = torch.randint(0, 1, (msa_depth + 4, n)).to(device)
# We need to do a dry run through the model so the chunk size tuners'
# trial runs (which run during the first-ever model iteration) aren't
# baked into the trace. There's no need to run the entire thing,
# though; we just need to run one block from each transformer stack.
evoformer_blocks = model.evoformer.blocks
model.evoformer.blocks = evoformer_blocks[:1]
extra_msa_blocks = model.extra_msa_stack.blocks
model.extra_msa_stack.blocks = extra_msa_blocks[:1]
if(model.template_config.enabled):
template_pair_stack_blocks = model.template_pair_stack.blocks
model.template_pair_stack.blocks = template_pair_stack_blocks[:1]
single_recycling_iter_input = tensor_tree_map(
lambda t: t[..., :1], sample_input,
)
with torch.no_grad():
_ = model(single_recycling_iter_input)
model.evoformer.blocks = evoformer_blocks
model.extra_msa_stack.blocks = extra_msa_blocks
del evoformer_blocks, extra_msa_blocks
if(model.template_config.enabled):
model.template_pair_stack.blocks = template_pair_stack_blocks
del template_pair_stack_blocks
def get_tuned_chunk_size(module):
tuner = module.chunk_size_tuner
chunk_size = tuner.cached_chunk_size
# After our trial run above, this should always be set
assert(chunk_size is not None)
return chunk_size
# Fetch the resulting chunk sizes
evoformer_chunk_size = model.globals.chunk_size
if(model.evoformer.chunk_size_tuner is not None):
evoformer_chunk_size = get_tuned_chunk_size(model.evoformer)
extra_msa_chunk_size = model.globals.chunk_size
if(model.extra_msa_stack.chunk_size_tuner is not None):
extra_msa_chunk_size = get_tuned_chunk_size(model.extra_msa_stack)
if(model.template_config.enabled):
template_pair_stack_chunk_size = model.globals.chunk_size
if(model.template_pair_stack.chunk_size_tuner is not None):
template_pair_stack_chunk_size = get_tuned_chunk_size(
model.template_pair_stack
)
def trace_block(block, block_inputs):
# Yes, yes, I know
with contextlib.redirect_stderr(None):
traced_block = torch.jit.trace(block, block_inputs)
traced_block = torch.jit.freeze(traced_block, optimize_numerics=True)
# It would be nice to use this, but its runtimes are extremely
# unpredictable
# traced_block = torch.jit.optimize_for_inference(traced_block)
# All trace inputs need to be tensors. This wrapper takes care of that
def traced_block_wrapper(*args, **kwargs):
to_tensor = lambda t: torch.tensor(t) if type(t) != torch.Tensor else t
args = [to_tensor(a) for a in args]
kwargs = {k: to_tensor(v) for k,v in kwargs.items()}
return traced_block(*args, **kwargs)
return traced_block_wrapper
def verify_arg_order(fn, arg_list):
""" Because it's difficult to specify keyword arguments of Module
functions during tracing, we need to pass them as a tuple. As a
sanity check, we manually verify their order here.
"""
fn_arg_names = fn.__code__.co_varnames
# Remove the "self" parameter
assert(fn_arg_names[0] == "self")
fn_arg_names = fn_arg_names[1:]
# Trim unspecified arguments
fn_arg_names = fn_arg_names[:len(arg_list)]
name_tups = list(zip(fn_arg_names, [n for n, _ in arg_list]))
assert(all([n1 == n2 for n1, n2 in name_tups]))
evoformer_attn_chunk_size = max(
model.globals.chunk_size, evoformer_chunk_size // 4
)
# MSA row attention
msa_att_row_arg_tuples = [
("m", m),
("z", z),
("mask", msa_mask),
("chunk_size", torch.tensor(evoformer_attn_chunk_size)),
("use_memory_efficient_kernel", torch.tensor(False)),
("use_lma", torch.tensor(model.globals.use_lma)),
]
verify_arg_order(
model.evoformer.blocks[0].msa_att_row.forward,
msa_att_row_arg_tuples
)
msa_att_row_args = [arg for _, arg in msa_att_row_arg_tuples]
with torch.no_grad():
for b in model.evoformer.blocks:
traced_block = trace_block(
b.msa_att_row, msa_att_row_args
)
del b.msa_att_row
b.msa_att_row = traced_block
# MSA col attention
msa_att_col_arg_tuples = [
("m", m),
("mask", msa_mask),
("chunk_size", torch.tensor(evoformer_chunk_size)),
("use_lma", torch.tensor(model.globals.use_lma)),
("use_flash", torch.tensor(model.globals.use_flash)),
]
verify_arg_order(
model.evoformer.blocks[0].msa_att_col.forward,
msa_att_col_arg_tuples
)
msa_att_col_args = [arg for _, arg in msa_att_col_arg_tuples]
with torch.no_grad():
for b in model.evoformer.blocks:
traced_block = trace_block(
b.msa_att_col, msa_att_col_args
)
del b.msa_att_col
b.msa_att_col = traced_block
# OPM
opm_arg_tuples = [
("m", m),
("mask", msa_mask.float()),
("chunk_size", torch.tensor(evoformer_chunk_size)),
("inplace_safe", torch.tensor(True)),
]
verify_arg_order(
model.evoformer.blocks[0].core.outer_product_mean.forward,
opm_arg_tuples
)
opm_args = [arg for _, arg in opm_arg_tuples]
with torch.no_grad():
for b in model.evoformer.blocks:
traced_block = trace_block(
b.core.outer_product_mean, opm_args
)
del b.core.outer_product_mean
b.core.outer_product_mean = traced_block
# Triangular multiplicative update (out)
tri_mul_out_arg_tuples = [
("z", z),
("mask", pair_mask.float()),
("inplace_safe", torch.tensor(True)),
("_add_with_inplace", torch.tensor(True)),
]
verify_arg_order(
model.evoformer.blocks[0].core.tri_mul_out.forward,
tri_mul_out_arg_tuples
)
tri_mul_out_args = [arg for _, arg in tri_mul_out_arg_tuples]
with torch.no_grad():
for b in model.evoformer.blocks:
traced_block = trace_block(
b.core.tri_mul_out, tri_mul_out_args
)
del b.core.tri_mul_out
b.core.tri_mul_out = traced_block
# Triangular multiplicative update (in)
tri_mul_in_arg_tuples = [
("z", z),
("mask", pair_mask.float()),
("inplace_safe", torch.tensor(True)),
("_add_with_inplace", torch.tensor(True)),
]
verify_arg_order(
model.evoformer.blocks[0].core.tri_mul_in.forward,
tri_mul_in_arg_tuples
)
tri_mul_in_args = [arg for _, arg in tri_mul_in_arg_tuples]
with torch.no_grad():
for b in model.evoformer.blocks:
traced_block = trace_block(
b.core.tri_mul_in, tri_mul_in_args
)
del b.core.tri_mul_in
b.core.tri_mul_in = traced_block
# Triangular attention (start)
tri_att_start_arg_tuples = [
("x", z),
("mask", pair_mask.float()),
("chunk_size", torch.tensor(evoformer_attn_chunk_size)),
("use_memory_efficient_kernel", torch.tensor(False)),
("use_lma", torch.tensor(model.globals.use_lma)),
("inplace_safe", torch.tensor(True)),
]
verify_arg_order(
model.evoformer.blocks[0].core.tri_att_start.forward,
tri_att_start_arg_tuples
)
tri_att_start_args = [arg for _, arg in tri_att_start_arg_tuples]
with torch.no_grad():
for b in model.evoformer.blocks:
traced_block = trace_block(
b.core.tri_att_start, tri_att_start_args
)
del b.core.tri_att_start
b.core.tri_att_start = traced_block
# Triangular attention (end)
tri_att_end_arg_tuples = [
("x", z.transpose(-2, -3)),
("mask", pair_mask.transpose(-1, -2).float()),
("chunk_size", torch.tensor(evoformer_attn_chunk_size)),
("use_memory_efficient_kernel", torch.tensor(False)),
("use_lma", torch.tensor(model.globals.use_lma)),
("inplace_safe", torch.tensor(True)),
]
verify_arg_order(
model.evoformer.blocks[0].core.tri_att_end.forward,
tri_att_end_arg_tuples
)
tri_att_end_args = [arg for _, arg in tri_att_end_arg_tuples]
with torch.no_grad():
for b in model.evoformer.blocks:
traced_block = trace_block(
b.core.tri_att_end, tri_att_end_args
)
del b.core.tri_att_end
b.core.tri_att_end = traced_block
#evoformer_arg_tuples = [
# ("m", m),
# ("z", z),
# ("msa_mask", msa_mask),
# ("pair_mask", pair_mask),
# ("chunk_size", torch.tensor(evoformer_chunk_size)),
# ("use_lma", torch.tensor(model.globals.use_lma)),
# ("use_flash", torch.tensor(model.globals.use_flash)),
# ("inplace_safe", torch.tensor(1)),
# ("_mask_trans", torch.tensor(model.config._mask_trans)),
# ("_attn_chunk_size", torch.tensor(evoformer_attn_chunk_size)),
#]
#verify_arg_order(model.evoformer.blocks[0].forward, evoformer_arg_tuples)
#evoformer_args = [arg for _, arg in evoformer_arg_tuples]
#with torch.no_grad():
# traced_evoformer_stack = []
# for b in model.evoformer.blocks:
# traced_block = trace_block(b, evoformer_args)
# traced_evoformer_stack.append(traced_block)
#del model.evoformer.blocks
#model.evoformer.blocks = traced_evoformer_stack
# with torch.no_grad():
# for b in model.evoformer.blocks:
# _ = b(*evoformer_args)
#
# with torch.no_grad():
# for b in model.evoformer.blocks:
# _ = b(*evoformer_args)
# extra_msa_attn_chunk_size = max(
# model.globals.chunk_size, extra_msa_chunk_size // 4
# )
# extra_msa_arg_tuples = [
# ("m", a),
# ("z", z),
# ("msa_mask", extra_msa_mask),
# ("pair_mask", pair_mask),
# ("chunk_size", torch.tensor(extra_msa_chunk_size)),
# ("use_lma", torch.tensor(model.globals.use_lma)),
# ("inplace_safe", torch.tensor(1)),
# ("_mask_trans", torch.tensor(model.config._mask_trans)),
# ("_attn_chunk_size", torch.tensor(extra_msa_attn_chunk_size)),
# ]
# verify_arg_order(
# model.extra_msa_stack.blocks[0].forward, extra_msa_arg_tuples
# )
# extra_msa_args = [arg for _, arg in extra_msa_arg_tuples]
# with torch.no_grad():
# traced_extra_msa_stack = []
# for b in model.extra_msa_stack.blocks:
# traced_block = trace_block(b, extra_msa_args)
# traced_extra_msa_stack.append(traced_block)
#
# del model.extra_msa_stack.blocks
# model.extra_msa_stack.blocks = traced_extra_msa_stack
# if(model.template_config.enabled):
# template_pair_stack_attn_chunk_size = max(
# model.globals.chunk_size, template_pair_stack_chunk_size // 4
# )
# template_pair_stack_arg_tuples = [
# ("z", t),
# ("mask", template_pair_mask),
# ("chunk_size", torch.tensor(template_pair_stack_chunk_size)),
# ("use_lma", torch.tensor(model.globals.use_lma)),
# ("inplace_safe", torch.tensor(1)),
# ("_mask_trans", torch.tensor(model.config._mask_trans)),
# ("_attn_chunk_size", torch.tensor(
# template_pair_stack_attn_chunk_size
# )),
# ]
# verify_arg_order(
# model.template_pair_stack.blocks[0].forward,
# template_pair_stack_arg_tuples
# )
# template_pair_stack_args = [
# arg for _, arg in template_pair_stack_arg_tuples
# ]
#
# with torch.no_grad():
# traced_template_pair_stack = []
# for b in model.template_pair_stack.blocks:
# traced_block = trace_block(b, template_pair_stack_args)
# traced_template_pair_stack.append(traced_block)
#
# del model.template_pair_stack.blocks
# model.template_pair_stack.blocks = traced_template_pair_stack
# A second dry run is needed after tracing before the model reaches top
# speed. The reason for this is unclear.
two_recycling_iter_input = tensor_tree_map(
lambda t: t[..., :2], sample_input,
)
with torch.no_grad():
_ = model(two_recycling_iter_input)
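# Shape illustration (assumed layout): each input tensor carries a trailing
# recycling dimension, so t[..., :2] restricts the dry run to two recycling
# iterations, e.g. an aatype tensor of shape [*, N_res, N_recycle] becomes
# [*, N_res, 2].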
......@@ -12,49 +12,158 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from datetime import date
import logging
import math
import numpy as np
import os
from openfold.utils.script_utils import load_models_from_command_line, parse_fasta, run_model, prep_output, \
update_timings, relax_protein
logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)
import pickle
import random
import sys
import time
import torch
torch_versions = torch.__version__.split(".")
torch_major_version = int(torch_versions[0])
torch_minor_version = int(torch_versions[1])
if(
torch_major_version > 1 or
(torch_major_version == 1 and torch_minor_version >= 12)
):
# Gives a large speedup on Ampere-class GPUs
torch.set_float32_matmul_precision("high")
torch.set_grad_enabled(False)
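# Note: "high" matmul precision permits TF32 kernels on Ampere-class GPUs,
# and with autograd disabled globally this script runs pure inference.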
from openfold.config import model_config
from openfold.data import (
data_pipeline,
feature_pipeline,
templates,
)
from openfold.data.tools import hhsearch, hmmsearch
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.data import templates, feature_pipeline, data_pipeline
from openfold.np import residue_constants, protein
import openfold.np.relax.relax as relax
from openfold.utils.import_weights import (
import_jax_weights_,
)
from openfold.utils.tensor_utils import (
tensor_tree_map,
)
from openfold.utils.trace_utils import (
pad_feature_dict_seq,
trace_model_,
)
from scripts.utils import add_data_args
TRACING_INTERVAL = 50
def precompute_alignments(tags, seqs, alignment_dir, args, is_multimer):
for tag, seq in zip(tags, seqs):
tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
with open(tmp_fasta_path, "w") as fp:
fp.write(f">{tag}\n{seq}")
if is_multimer:
local_alignment_dir = alignment_dir
else:
local_alignment_dir = os.path.join(alignment_dir, tag)
if(args.use_precomputed_alignments is None and not os.path.isdir(local_alignment_dir)):
logger.info(f"Generating alignments for {tag}...")
os.makedirs(local_alignment_dir)
alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
no_cpus=args.cpus,
)
alignment_runner.run(
tmp_fasta_path, local_alignment_dir
)
else:
logger.info(
f"Using precomputed alignments for {tag} at {alignment_dir}..."
)
# Remove temporary FASTA file
os.remove(tmp_fasta_path)
def round_up_seqlen(seqlen):
return int(math.ceil(seqlen / TRACING_INTERVAL)) * TRACING_INTERVAL
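# Worked examples with TRACING_INTERVAL = 50:
#   round_up_seqlen(1)   -> 50
#   round_up_seqlen(117) -> 150
#   round_up_seqlen(150) -> 150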
def generate_feature_dict(
tags,
seqs,
alignment_dir,
data_processor,
args,
):
tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
if len(seqs) == 1:
tag = tags[0]
seq = seqs[0]
with open(tmp_fasta_path, "w") as fp:
fp.write(f">{tag}\n{seq}")
local_alignment_dir = os.path.join(alignment_dir, tag)
feature_dict = data_processor.process_fasta(
fasta_path=tmp_fasta_path, alignment_dir=local_alignment_dir
)
elif "multimer" in args.config_preset:
with open(tmp_fasta_path, "w") as fp:
fp.write(
'\n'.join([f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)])
)
feature_dict = data_processor.process_fasta(
fasta_path=tmp_fasta_path, alignment_dir=alignment_dir,
)
else:
with open(tmp_fasta_path, "w") as fp:
fp.write(
'\n'.join([f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)])
)
feature_dict = data_processor.process_multiseq_fasta(
fasta_path=tmp_fasta_path, super_alignment_dir=alignment_dir,
)
# Remove temporary FASTA file
os.remove(tmp_fasta_path)
return feature_dict
def list_files_with_extensions(directory, extensions):
    return [f for f in os.listdir(directory) if f.endswith(extensions)]
def main(args):
config = model_config(args.model_name)
model = AlphaFold(config)
model = model.eval()
import_jax_weights_(model, args.param_path, version=args.model_name)
#script_preset_(model)
model = model.to(args.model_device)
# Create the output directory
os.makedirs(args.output_dir, exist_ok=True)
config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference)
is_multimer = "multimer" in args.model_name
if (args.trace_model):
if (not config.data.predict.fixed_size):
raise ValueError(
"Tracing requires that fixed_size mode be enabled in the config"
)
is_multimer = "multimer" in args.config_preset
if(is_multimer):
if(not args.use_precomputed_alignments):
......@@ -120,151 +229,150 @@ def main(args):
output_dir_base = args.output_dir
random_seed = args.data_random_seed
if random_seed is None:
random_seed = random.randrange(sys.maxsize)
random_seed = random.randrange(2**32)
feature_processor = feature_pipeline.FeaturePipeline(
config.data
)
np.random.seed(random_seed)
torch.manual_seed(random_seed + 1)
feature_processor = feature_pipeline.FeaturePipeline(config.data)
if not os.path.exists(output_dir_base):
os.makedirs(output_dir_base)
if(not args.use_precomputed_alignments):
if args.use_precomputed_alignments is None:
alignment_dir = os.path.join(output_dir_base, "alignments")
else:
alignment_dir = args.use_precomputed_alignments
for fasta_path in os.listdir(args.fasta_dir):
if(not ".fasta" == os.path.splitext(fasta_path)[-1]):
print(f"Skipping {fasta_path}. Not a .fasta file...")
continue
fasta_path = os.path.join(args.fasta_dir, fasta_path)
tag_list = []
seq_list = []
for fasta_file in list_files_with_extensions(args.fasta_dir, (".fasta", ".fa")):
# Gather input sequences
fasta_path = os.path.join(args.fasta_dir, fasta_file)
with open(fasta_path, "r") as fp:
data = fp.read()
tags, seqs = parse_fasta(data)
lines = [
l.replace('\n', '')
for prot in data.split('>') for l in prot.strip().split('\n', 1)
][1:]
tags, seqs = lines[::2], lines[1::2]
if ((not is_multimer) and len(tags) != 1):
print(
f"{fasta_path} contains more than one sequence but "
f"multimer mode is not enabled. Skipping..."
)
continue
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tag = '-'.join(tags)
tag_list.append((tag, tags))
seq_list.append(seqs)
seq_sort_fn = lambda target: sum([len(s) for s in target[1]])
sorted_targets = sorted(zip(tag_list, seq_list), key=seq_sort_fn)
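# Sorting targets by total sequence length means the padded (rounded)
# length only ever grows, so when tracing is enabled the model is retraced
# at most once per TRACING_INTERVAL bucket (see the
# rounded_seqlen > cur_tracing_interval check below).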
feature_dicts = {}
model_generator = load_models_from_command_line(
config,
args.model_device,
args.openfold_checkpoint_path,
args.jax_param_path,
args.output_dir)
for model, output_directory in model_generator:
cur_tracing_interval = 0
for (tag, tags), seqs in sorted_targets:
output_name = f'{tag}_{args.config_preset}'
if args.output_postfix is not None:
output_name = f'{output_name}_{args.output_postfix}'
# Does nothing if the alignments have already been computed
precompute_alignments(tags, seqs, alignment_dir, args, is_multimer)
for tag, seq in zip(tags, seqs):
tag, seq = tags[0], seqs[0]
local_alignment_dir = os.path.join(alignment_dir, tag)
if(args.use_precomputed_alignments is None):
if not os.path.exists(local_alignment_dir):
os.makedirs(local_alignment_dir)
alignment_runner.run(
fasta_path, local_alignment_dir
)
feature_dict = feature_dicts.get(tag, None)
if(feature_dict is None):
feature_dict = generate_feature_dict(
tags,
seqs,
alignment_dir,
data_processor,
args,
)
if(is_multimer):
local_alignment_dir = alignment_dir
else:
local_alignment_dir = os.path.join(
alignment_dir,
tags[0],
)
feature_dict = data_processor.process_fasta(
fasta_path=fasta_path, alignment_dir=local_alignment_dir
)
if(args.trace_model):
n = feature_dict["aatype"].shape[-2]
rounded_seqlen = round_up_seqlen(n)
feature_dict = pad_feature_dict_seq(
feature_dict, rounded_seqlen,
)
processed_feature_dict = feature_processor.process_features(
feature_dict, mode='predict', is_multimer=is_multimer,
)
logging.info("Executing model...")
batch = processed_feature_dict
with torch.no_grad():
batch = {
feature_dicts[tag] = feature_dict
processed_feature_dict = feature_processor.process_features(
feature_dict, mode='predict', is_multimer=is_multimer
)
processed_feature_dict = {
k:torch.as_tensor(v, device=args.model_device)
for k,v in batch.items()
for k,v in processed_feature_dict.items()
}
t = time.perf_counter()
chunk_size = model.globals.chunk_size
try:
model.globals.chunk_size = None
out = model(batch)
except RuntimeError as e:
model.globals.chunk_size = chunk_size
out = model(batch)
logging.info(f"Inference time: {time.perf_counter() - t}")
if (args.trace_model):
if (rounded_seqlen > cur_tracing_interval):
logger.info(
f"Tracing model at {rounded_seqlen} residues..."
)
t = time.perf_counter()
trace_model_(model, processed_feature_dict)
tracing_time = time.perf_counter() - t
logger.info(
f"Tracing time: {tracing_time}"
)
cur_tracing_interval = rounded_seqlen
out = run_model(model, processed_feature_dict, tag, args.output_dir)
# Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
plddt = out["plddt"]
mean_plddt = np.mean(plddt)
plddt_b_factors = np.repeat(
plddt[..., None], residue_constants.atom_type_num, axis=-1
)
processed_feature_dict = tensor_tree_map(
lambda x: np.array(x[..., -1].cpu()),
processed_feature_dict
)
unrelaxed_protein = protein.from_prediction(
features=batch,
result=out,
b_factors=plddt_b_factors,
remove_leading_feature_dimension=not is_multimer,
)
out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
unrelaxed_protein = prep_output(
out,
processed_feature_dict,
feature_dict,
feature_processor,
args.config_preset,
args.multimer_ri_gap,
args.subtract_plddt
)
# Save the unrelaxed PDB.
unrelaxed_output_path = os.path.join(
args.output_dir, f'{tag}_{args.model_name}_unrelaxed.pdb'
output_directory, f'{output_name}_unrelaxed.pdb'
)
with open(unrelaxed_output_path, 'w') as f:
f.write(protein.to_pdb(unrelaxed_protein))
print(unrelaxed_output_path)
print("asdjfh klasjdhf lkasjdhf lkjasdhflkjasdh fkl jasdhfklj hasdkljf hasldkjfh lkasjdfh lkajsdhflk asd")
with open(unrelaxed_output_path, 'w') as fp:
fp.write(protein.to_pdb(unrelaxed_protein))
amber_relaxer = relax.AmberRelaxation(
use_gpu=(args.model_device != "cpu"),
**config.relax,
)
# Relax the prediction.
t = time.perf_counter()
visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default="")
if("cuda" in args.model_device):
device_no = args.model_device.split(":")[-1]
os.environ["CUDA_VISIBLE_DEVICES"] = device_no
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
logging.info(f"Relaxation time: {time.perf_counter() - t}")
# Save the relaxed PDB.
relaxed_output_path = os.path.join(
args.output_dir, f'{tag}_{args.model_name}_relaxed.pdb'
)
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)
logger.info(f"Output written to {unrelaxed_output_path}...")
if not args.skip_relaxation:
# Relax the prediction.
logger.info(f"Running relaxation on {unrelaxed_output_path}...")
relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name)
if(args.save_outputs):
if args.save_outputs:
output_dict_path = os.path.join(
args.output_dir, f'{tag}_{args.model_name}_output_dict.pkl'
output_directory, f'{output_name}_output_dict.pkl'
)
with open(output_dict_path, "wb") as fp:
pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL)
logger.info(f"Model output written to {output_dict_path}...")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"fasta_dir", type=str,
help="Path to directory containing FASTA files, one sequence per file"
)
parser.add_argument(
"template_mmcif_dir", type=str,
......@@ -284,18 +392,22 @@ if __name__ == "__main__":
device name is accepted (e.g. "cpu", "cuda:0")"""
)
parser.add_argument(
"--model_name", type=str, default="model_1",
help="""Name of a model config. Choose one of model_{1-5} or
model_{1-5}_ptm, as defined on the AlphaFold GitHub."""
"--config_preset", type=str, default="model_1",
help="""Name of a model config preset defined in openfold/config.py"""
)
parser.add_argument(
"--jax_param_path", type=str, default=None,
help="""Path to JAX model parameters. If None, and openfold_checkpoint_path
is also None, parameters are selected automatically according to
the model name from openfold/resources/params"""
)
parser.add_argument(
"--param_path", type=str, default=None,
help="""Path to model parameters. If None, parameters are selected
automatically according to the model name from
openfold/resources/params"""
"--openfold_checkpoint_path", type=str, default=None,
help="""Path to OpenFold checkpoint. Can be either a DeepSpeed
checkpoint directory or a .pt file"""
)
parser.add_argument(
"--save_outputs", type=bool, default=False,
"--save_outputs", action="store_true", default=False,
help="Whether to save all model outputs, including embeddings, etc."
)
parser.add_argument(
......@@ -303,19 +415,45 @@ if __name__ == "__main__":
help="""Number of CPUs with which to run alignment tools"""
)
parser.add_argument(
'--preset', type=str, default='full_dbs',
"--preset", type=str, default='full_dbs',
choices=('reduced_dbs', 'full_dbs')
)
parser.add_argument(
'--data_random_seed', type=str, default=None
"--output_postfix", type=str, default=None,
help="""Postfix for output prediction filenames"""
)
parser.add_argument(
"--data_random_seed", type=str, default=None
)
parser.add_argument(
"--skip_relaxation", action="store_true", default=False,
)
parser.add_argument(
"--multimer_ri_gap", type=int, default=200,
help="""Residue index offset between multiple sequences, if provided"""
)
parser.add_argument(
"--trace_model", action="store_true", default=False,
help="""Whether to convert parts of each model to TorchScript.
Significantly improves runtime at the cost of lengthy
'compilation.' Useful for large batch jobs."""
)
parser.add_argument(
"--subtract_plddt", action="store_true", default=False,
help=""""Whether to output (100 - pLDDT) in the B-factor column instead
of the pLDDT itself"""
)
parser.add_argument(
"--long_sequence_inference", action="store_true", default=False,
help="""enable options to reduce memory usage at the cost of speed, helps longer sequences fit into GPU memory, see the README for details"""
)
add_data_args(parser)
args = parser.parse_args()
if(args.param_path is None):
args.param_path = os.path.join(
if(args.jax_param_path is None and args.openfold_checkpoint_path is None):
args.jax_param_path = os.path.join(
"openfold", "resources", "params",
"params_" + args.model_name + ".npz"
"params_" + args.config_preset + ".npz"
)
if(args.model_device == "cpu" and torch.cuda.is_available()):
......
import argparse
import json
import os
def main(args):
db_path = os.path.join(args.output_db_path, f"{args.output_db_name}.db")
index_path = os.path.join(
args.output_db_path, f"{args.output_db_name}.index"
)
db_fp = open(db_path, "wb")
index = {}
db_offset = 0
for chain_alignment_dir in os.listdir(args.alignment_dir):
cad_path = os.path.join(args.alignment_dir, chain_alignment_dir)
for f in os.listdir(cad_path):
f_path = os.path.join(cad_path, f)
with open(f_path, "rb") as fp:
file_bytes = fp.read()
file_len = len(file_bytes)
file_list = index.setdefault(chain_alignment_dir, [])
file_list.append((f, db_offset, file_len))
db_fp.write(file_bytes)
db_offset += file_len
db_fp.close()
with open(index_path, "w") as fp:
json.dump(index, fp)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"alignment_dir", type=str,
help="""Path to precomputed alignment directory, with one subdirectory
per chain."""
)
parser.add_argument("output_db_path", type=str)
parser.add_argument("output_db_name", type=str)
args = parser.parse_args()
main(args)
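# Hypothetical companion reader (not part of this commit): each index entry
# is a (filename, offset, length) triple, so a single alignment file can be
# recovered from the packed .db with one seek.
def read_alignment_file(db_path, index, chain, filename):
    for name, offset, length in index[chain]:
        if name == filename:
            with open(db_path, "rb") as fp:
                fp.seek(offset)
                return fp.read(length)
    raise KeyError(f"{filename} not indexed for chain {chain}")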
""" Unifies databases created with create_alignment_db.py """
import argparse
import json
import os
def main(args):
super_index = {}
for f in os.listdir(args.alignment_db_dir):
if(not os.path.splitext(f)[-1] == ".index"):
continue
with open(os.path.join(args.alignment_db_dir, f), "r") as fp:
index = json.load(fp)
db_name = f"{os.path.splitext(f)[0]}.db"
for k in index:
super_index[k] = {
"db": db_name,
"files": index[k],
}
with open(os.path.join(args.output_dir, "super.index"), "w") as fp:
json.dump(super_index, fp)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("alignment_db_dir", type=str, help="Path to directory containing alignment_dbs")
parser.add_argument("output_dir", type=str, help="Path in which to output super index")
args = parser.parse_args()
main(args)
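# Sketch of how a consumer might resolve a chain through the super index
# (the structure matches what main() writes above; the helper name is
# illustrative):
def resolve_chain(super_index, alignment_db_dir, chain):
    entry = super_index[chain]
    db_path = os.path.join(alignment_db_dir, entry["db"])
    # entry["files"] retains the per-database (filename, offset, length)
    # triples, now qualified by the owning .db file.
    return db_path, entry["files"]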