Commit 3ad85a19 authored by Sam DeLuca

Merge remote-tracking branch 'cyrus/main' into run-multiple-models

parents 43b8c6f9 6da2cdaf
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import logging
import math
from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional
import torch
from openfold.utils.tensor_utils import (
tree_map,
tensor_tree_map,
)
def _fetch_dims(tree):
shapes = []
tree_type = type(tree)
if tree_type is dict:
for v in tree.values():
shapes.extend(_fetch_dims(v))
elif tree_type is list or tree_type is tuple:
for t in tree:
shapes.extend(_fetch_dims(t))
elif tree_type is torch.Tensor:
shapes.append(tree.shape)
else:
raise ValueError("Not supported")
return shapes
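# A quick behavior sketch (hypothetical values, not part of the original
# module): shapes are collected depth-first from the pytree's leaves.
#   _fetch_dims({"a": torch.zeros(2, 3), "b": [torch.zeros(4)]})
#   == [torch.Size([2, 3]), torch.Size([4])]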
@torch.jit.ignore
def _flat_idx_to_idx(
flat_idx: int,
dims: Tuple[int, ...],
) -> Tuple[int, ...]:
idx = []
for d in reversed(dims):
idx.append(flat_idx % d)
flat_idx = flat_idx // d
return tuple(reversed(idx))
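# Example (hypothetical values): unraveling a flat index into a (2, 3)
# batch shape in C (row-major) order:
#   _flat_idx_to_idx(4, (2, 3)) == (1, 1)  # since 4 == 1 * 3 + 1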
@torch.jit.ignore
def _get_minimal_slice_set(
start: Sequence[int],
end: Sequence[int],
dims: Sequence[int],
start_edges: Optional[Sequence[bool]] = None,
end_edges: Optional[Sequence[bool]] = None,
) -> Sequence[Tuple[slice, ...]]:
"""
Produces an ordered sequence of tensor slices that, when used in
sequence on a tensor with shape dims, yields tensors that contain every
leaf in the contiguous range [start, end]. Care is taken to yield a
short sequence of slices, and perhaps even the shortest possible (I'm
pretty sure it's the latter).
end is INCLUSIVE.
"""
# start_edges and end_edges both indicate whether, starting from any given
# dimension, the start/end index is at the top/bottom edge of the
# corresponding tensor, modeled as a tree
def reduce_edge_list(l):
tally = 1
for i in range(len(l)):
reversed_idx = -1 * (i + 1)
l[reversed_idx] *= tally
tally = l[reversed_idx]
if(start_edges is None):
start_edges = [s == 0 for s in start]
reduce_edge_list(start_edges)
if(end_edges is None):
end_edges = [e == (d - 1) for e,d in zip(end, dims)]
reduce_edge_list(end_edges)
# Base cases. Either start/end are empty and we're done, or the final,
# one-dimensional tensor can be simply sliced
if(len(start) == 0):
return [tuple()]
elif(len(start) == 1):
return [(slice(start[0], end[0] + 1),)]
slices = []
path = []
# Dimensions common to start and end can be selected directly
for s,e in zip(start, end):
if(s == e):
path.append(slice(s, s + 1))
else:
break
path = tuple(path)
divergence_idx = len(path)
# start == end, and we're done
if(divergence_idx == len(dims)):
return [tuple(path)]
def upper():
sdi = start[divergence_idx]
return [
path + (slice(sdi, sdi + 1),) + s for s in
_get_minimal_slice_set(
start[divergence_idx + 1:],
[d - 1 for d in dims[divergence_idx + 1:]],
dims[divergence_idx + 1:],
start_edges=start_edges[divergence_idx + 1:],
end_edges=[1 for _ in end_edges[divergence_idx + 1:]]
)
]
def lower():
edi = end[divergence_idx]
return [
path + (slice(edi, edi + 1),) + s for s in
_get_minimal_slice_set(
[0 for _ in start[divergence_idx + 1:]],
end[divergence_idx + 1:],
dims[divergence_idx + 1:],
start_edges=[1 for _ in start_edges[divergence_idx + 1:]],
end_edges=end_edges[divergence_idx + 1:],
)
]
# If both start and end are at the edges of the subtree rooted at
# divergence_idx, we can just select the whole subtree at once
if(start_edges[divergence_idx] and end_edges[divergence_idx]):
slices.append(
path + (slice(start[divergence_idx], end[divergence_idx] + 1),)
)
# If just start is at the edge, we can grab almost all of the subtree,
# treating only the ragged bottom edge as an edge case
elif(start_edges[divergence_idx]):
slices.append(
path + (slice(start[divergence_idx], end[divergence_idx]),)
)
slices.extend(lower())
# Analogous to the previous case, but the top is ragged this time
elif(end_edges[divergence_idx]):
slices.extend(upper())
slices.append(
path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),)
)
# If both sides of the range are ragged, we need to handle both sides
# separately. If there's contiguous meat in between them, we can index it
# in one big chunk
else:
slices.extend(upper())
middle_ground = end[divergence_idx] - start[divergence_idx]
if(middle_ground > 1):
slices.append(
path + (slice(start[divergence_idx] + 1, end[divergence_idx]),)
)
slices.extend(lower())
return [tuple(s) for s in slices]
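# Worked example (hypothetical values): covering the inclusive flat range
# [1, 4] of a tensor with batch shape (2, 3) takes exactly two slices, the
# tail of row 0 plus the head of row 1:
#   _get_minimal_slice_set([0, 1], [1, 1], (2, 3))
#   == [(slice(0, 1), slice(1, 3)), (slice(1, 2), slice(0, 2))]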
@torch.jit.ignore
def _chunk_slice(
t: torch.Tensor,
flat_start: int,
flat_end: int,
no_batch_dims: int,
) -> torch.Tensor:
"""
Equivalent to
t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]
but without the need for the initial reshape call, which can be
memory-intensive in certain situations. The only reshape operations
in this function are performed on sub-tensors that scale with
(flat_end - flat_start), the chunk size.
"""
batch_dims = t.shape[:no_batch_dims]
start_idx = list(_flat_idx_to_idx(flat_start, batch_dims))
# _get_minimal_slice_set is inclusive
end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims))
# Get an ordered list of slices to perform
slices = _get_minimal_slice_set(
start_idx,
end_idx,
batch_dims,
)
sliced_tensors = [t[s] for s in slices]
return torch.cat(
[s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors]
)
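# Sanity-check sketch (hypothetical values, not part of the original
# module): both paths select flat sub-batches 1 through 4:
#   t = torch.arange(24).view(2, 3, 4)
#   assert torch.equal(
#       _chunk_slice(t, flat_start=1, flat_end=5, no_batch_dims=2),
#       t.reshape(-1, 4)[1:5],
#   )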
def chunk_layer(
layer: Callable,
inputs: Dict[str, Any],
chunk_size: int,
no_batch_dims: int,
low_mem: bool = False,
_out: Any = None,
_add_into_out: bool = False,
) -> Any:
"""
Implements the "chunking" procedure described in section 1.11.8 of the
AlphaFold 2 supplement.
Layer outputs and inputs are assumed to be simple "pytrees,"
consisting only of (arbitrarily nested) lists, tuples, and dicts with
torch.Tensor leaves.
Args:
layer:
The layer to be applied chunk-wise
inputs:
A (non-nested) dictionary of keyworded inputs. All leaves must
be tensors and must share the same batch dimensions.
chunk_size:
The number of sub-batches per chunk. If multiple batch
dimensions are specified, a "sub-batch" is defined as a single
indexing of all batch dimensions simultaneously (s.t. the
number of sub-batches is the product of the batch dimensions).
no_batch_dims:
How many of the initial dimensions of each input tensor can
be considered batch dimensions.
low_mem:
Avoids flattening potentially large input tensors. Unnecessary in
most cases, and slightly slower than the default setting.
Returns:
The reassembled output of the layer on the inputs.
"""
if not (len(inputs) > 0):
raise ValueError("Must provide at least one input")
initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
def _prep_inputs(t):
if(not low_mem):
# Expand to the full batch shape unless every batch dim is already 1,
# in which case the tensor is left to broadcast chunk-wise instead
if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
t = t.reshape(-1, *t.shape[no_batch_dims:])
else:
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
return t
prepped_inputs = tensor_tree_map(_prep_inputs, inputs)
prepped_outputs = None
if(_out is not None):
reshape_fn = lambda t: t.view([-1] + list(t.shape[no_batch_dims:]))
prepped_outputs = tensor_tree_map(reshape_fn, _out)
flat_batch_dim = 1
for d in orig_batch_dims:
flat_batch_dim *= d
no_chunks = flat_batch_dim // chunk_size + (
flat_batch_dim % chunk_size != 0
)
i = 0
out = prepped_outputs
for _ in range(no_chunks):
# Chunk the input
if(not low_mem):
select_chunk = (
lambda t: t[i : i + chunk_size] if t.shape[0] != 1 else t
)
else:
select_chunk = (
partial(
_chunk_slice,
flat_start=i,
flat_end=min(flat_batch_dim, i + chunk_size),
no_batch_dims=len(orig_batch_dims)
)
)
chunks = tensor_tree_map(select_chunk, prepped_inputs)
# Run the layer on the chunk
output_chunk = layer(**chunks)
# Allocate space for the output
if out is None:
allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:])
out = tensor_tree_map(allocate, output_chunk)
# Put the chunk in its pre-allocated space
out_type = type(output_chunk)
if out_type is dict:
def assign(d1, d2):
for k, v in d1.items():
if type(v) is dict:
assign(v, d2[k])
else:
if(_add_into_out):
v[i: i + chunk_size] += d2[k]
else:
v[i: i + chunk_size] = d2[k]
assign(out, output_chunk)
elif out_type is tuple:
for x1, x2 in zip(out, output_chunk):
if(_add_into_out):
x1[i: i + chunk_size] += x2
else:
x1[i : i + chunk_size] = x2
elif out_type is torch.Tensor:
if(_add_into_out):
out[i: i + chunk_size] += output_chunk
else:
out[i: i + chunk_size] = output_chunk
else:
raise ValueError("Not supported")
i += chunk_size
reshape = lambda t: t.view(orig_batch_dims + t.shape[1:])
out = tensor_tree_map(reshape, out)
return out
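# Usage sketch (hypothetical layer and shapes, not part of the original
# module): apply a linear layer over the flattened (4, 16) batch dims in
# sub-batches of 4:
#   layer = torch.nn.Linear(8, 8)
#   inputs = {"input": torch.rand(4, 16, 8)}
#   out = chunk_layer(layer, inputs, chunk_size=4, no_batch_dims=2)
#   # out.shape == (4, 16, 8), matching layer(inputs["input"])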
class ChunkSizeTuner:
def __init__(self,
# Heuristically, runtimes for most of the modules in the network
# plateau earlier than this on all GPUs I've run the model on.
max_chunk_size=256,
):
self.max_chunk_size = max_chunk_size
self.cached_chunk_size = None
self.cached_arg_data = None
def _determine_favorable_chunk_size(self, fn, args, min_chunk_size):
logging.info("Tuning chunk size...")
if(min_chunk_size >= self.max_chunk_size):
return min_chunk_size
candidates = [2**l for l in range(int(math.log(self.max_chunk_size, 2)) + 1)]
candidates = [c for c in candidates if c > min_chunk_size]
candidates = [min_chunk_size] + candidates
def test_chunk_size(chunk_size):
try:
with torch.no_grad():
fn(*args, chunk_size=chunk_size)
return True
except RuntimeError:
return False
min_viable_chunk_size_index = 0
i = len(candidates) - 1
while i > min_viable_chunk_size_index:
viable = test_chunk_size(candidates[i])
if(not viable):
i = (min_viable_chunk_size_index + i) // 2
else:
min_viable_chunk_size_index = i
i = (i + len(candidates) - 1) // 2
return candidates[min_viable_chunk_size_index]
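# Example (hypothetical values): with max_chunk_size=256 and
# min_chunk_size=4, the candidates are [4, 8, 16, 32, 64, 128, 256]. The
# search above returns the largest candidate for which
# fn(*args, chunk_size=c) completes without raising (e.g. without OOM).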
def _compare_arg_caches(self, ac1, ac2):
consistent = True
for a1, a2 in zip(ac1, ac2):
assert(type(a1) == type(a2))
if(type(a1) is list or type(a1) is tuple):
consistent &= self._compare_arg_caches(a1, a2)
elif(type(a1) is dict):
a1_items = [
v for _, v in sorted(a1.items(), key=lambda x: x[0])
]
a2_items = [
v for _, v in sorted(a2.items(), key=lambda x: x[0])
]
consistent &= self._compare_arg_caches(a1_items, a2_items)
else:
consistent &= a1 == a2
return consistent
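# Example (hypothetical values): caches that differ only in one recorded
# tensor shape are flagged as inconsistent:
#   self._compare_arg_caches([torch.Size([4, 8])], [torch.Size([4, 16])])
#   -> False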
def tune_chunk_size(self,
representative_fn: Callable,
args: Tuple[Any],
min_chunk_size: int,
) -> int:
consistent = True
remove_tensors = lambda a: a.shape if type(a) is torch.Tensor else a
arg_data = tree_map(remove_tensors, args, object)
if(self.cached_arg_data is not None):
# If args have changed shape/value, we need to re-tune
assert(len(self.cached_arg_data) == len(arg_data))
consistent = self._compare_arg_caches(
self.cached_arg_data, arg_data
)
else:
# First call; there is nothing cached to compare against
consistent = False
if(not consistent):
self.cached_chunk_size = self._determine_favorable_chunk_size(
representative_fn,
args,
min_chunk_size,
)
self.cached_arg_data = arg_data
return self.cached_chunk_size
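# Usage sketch (hypothetical callable, not part of the original module):
#   tuner = ChunkSizeTuner(max_chunk_size=128)
#   chunk_size = tuner.tune_chunk_size(
#       representative_fn=fn,  # hypothetical; must accept a chunk_size kwarg
#       args=(x,),
#       min_chunk_size=4,
#   )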
@@ -15,10 +15,10 @@
from functools import partial
import logging
import math
from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional
import torch
import torch.nn as nn
from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional
def add(m1, m2, inplace):
@@ -119,374 +119,3 @@ def tree_map(fn, tree, leaf_type):
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
def _fetch_dims(tree):
shapes = []
tree_type = type(tree)
if tree_type is dict:
for v in tree.values():
shapes.extend(_fetch_dims(v))
elif tree_type is list or tree_type is tuple:
for t in tree:
shapes.extend(_fetch_dims(t))
elif tree_type is torch.Tensor:
shapes.append(tree.shape)
else:
raise ValueError("Not supported")
return shapes
@torch.jit.ignore
def _flat_idx_to_idx(
flat_idx: int,
dims: Tuple[int],
) -> Tuple[int]:
idx = []
for d in reversed(dims):
idx.append(flat_idx % d)
flat_idx = flat_idx // d
return tuple(reversed(idx))
@torch.jit.ignore
def _get_minimal_slice_set(
start: Sequence[int],
end: Sequence[int],
dims: int,
start_edges: Optional[Sequence[bool]] = None,
end_edges: Optional[Sequence[bool]] = None,
) -> Sequence[Tuple[int]]:
"""
Produces an ordered sequence of tensor slices that, when used in
sequence on a tensor with shape dims, yields tensors that contain every
leaf in the contiguous range [start, end]. Care is taken to yield a
short sequence of slices, and perhaps even the shortest possible (I'm
pretty sure it's the latter).
end is INCLUSIVE.
"""
# start_edges and end_edges both indicate whether, starting from any given
# dimension, the start/end index is at the top/bottom edge of the
# corresponding tensor, modeled as a tree
def reduce_edge_list(l):
tally = 1
for i in range(len(l)):
reversed_idx = -1 * (i + 1)
l[reversed_idx] *= tally
tally = l[reversed_idx]
if(start_edges is None):
start_edges = [s == 0 for s in start]
reduce_edge_list(start_edges)
if(end_edges is None):
end_edges = [e == (d - 1) for e,d in zip(end, dims)]
reduce_edge_list(end_edges)
# Base cases. Either start/end are empty and we're done, or the final,
# one-dimensional tensor can be simply sliced
if(len(start) == 0):
return [tuple()]
elif(len(start) == 1):
return [(slice(start[0], end[0] + 1),)]
slices = []
path = []
# Dimensions common to start and end can be selected directly
for s,e in zip(start, end):
if(s == e):
path.append(slice(s, s + 1))
else:
break
path = tuple(path)
divergence_idx = len(path)
# start == end, and we're done
if(divergence_idx == len(dims)):
return [tuple(path)]
def upper():
sdi = start[divergence_idx]
return [
path + (slice(sdi, sdi + 1),) + s for s in
_get_minimal_slice_set(
start[divergence_idx + 1:],
[d - 1 for d in dims[divergence_idx + 1:]],
dims[divergence_idx + 1:],
start_edges=start_edges[divergence_idx + 1:],
end_edges=[1 for _ in end_edges[divergence_idx + 1:]]
)
]
def lower():
edi = end[divergence_idx]
return [
path + (slice(edi, edi + 1),) + s for s in
_get_minimal_slice_set(
[0 for _ in start[divergence_idx + 1:]],
end[divergence_idx + 1:],
dims[divergence_idx + 1:],
start_edges=[1 for _ in start_edges[divergence_idx + 1:]],
end_edges=end_edges[divergence_idx + 1:],
)
]
# If both start and end are at the edges of the subtree rooted at
# divergence_idx, we can just select the whole subtree at once
if(start_edges[divergence_idx] and end_edges[divergence_idx]):
slices.append(
path + (slice(start[divergence_idx], end[divergence_idx] + 1),)
)
# If just start is at the edge, we can grab almost all of the subtree,
# treating only the ragged bottom edge as an edge case
elif(start_edges[divergence_idx]):
slices.append(
path + (slice(start[divergence_idx], end[divergence_idx]),)
)
slices.extend(lower())
# Analogous to the previous case, but the top is ragged this time
elif(end_edges[divergence_idx]):
slices.extend(upper())
slices.append(
path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),)
)
# If both sides of the range are ragged, we need to handle both sides
# separately. If there's contiguous meat in between them, we can index it
# in one big chunk
else:
slices.extend(upper())
middle_ground = end[divergence_idx] - start[divergence_idx]
if(middle_ground > 1):
slices.append(
path + (slice(start[divergence_idx] + 1, end[divergence_idx]),)
)
slices.extend(lower())
return [tuple(s) for s in slices]
@torch.jit.ignore
def _chunk_slice(
t: torch.Tensor,
flat_start: int,
flat_end: int,
no_batch_dims: int,
) -> torch.Tensor:
"""
Equivalent to
t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]
but without the need for the initial reshape call, which can be
memory-intensive in certain situations. The only reshape operations
in this function are performed on sub-tensors that scale with
(flat_end - flat_start), the chunk size.
"""
batch_dims = t.shape[:no_batch_dims]
start_idx = list(_flat_idx_to_idx(flat_start, batch_dims))
# _get_minimal_slice_set is inclusive
end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims))
# Get an ordered list of slices to perform
slices = _get_minimal_slice_set(
start_idx,
end_idx,
batch_dims,
)
sliced_tensors = [t[s] for s in slices]
return torch.cat(
[s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors]
)
def chunk_layer(
layer: Callable,
inputs: Dict[str, Any],
chunk_size: int,
no_batch_dims: int,
low_mem: bool = False,
) -> Any:
"""
Implements the "chunking" procedure described in section 1.11.8.
Layer outputs and inputs are assumed to be simple "pytrees,"
consisting only of (arbitrarily nested) lists, tuples, and dicts with
torch.Tensor leaves.
Args:
layer:
The layer to be applied chunk-wise
inputs:
A (non-nested) dictionary of keyworded inputs. All leaves must
be tensors and must share the same batch dimensions.
chunk_size:
The number of sub-batches per chunk. If multiple batch
dimensions are specified, a "sub-batch" is defined as a single
indexing of all batch dimensions simultaneously (s.t. the
number of sub-batches is the product of the batch dimensions).
no_batch_dims:
How many of the initial dimensions of each input tensor can
be considered batch dimensions.
low_mem:
Avoids flattening potentially large input tensors. Unnecessary
in most cases, and is ever so slightly slower than the default
setting.
Returns:
The reassembled output of the layer on the inputs.
"""
if not (len(inputs) > 0):
raise ValueError("Must provide at least one input")
initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
def _prep_inputs(t):
# TODO: make this more memory efficient. This sucks
if(not low_mem):
if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
t = t.reshape(-1, *t.shape[no_batch_dims:])
else:
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
return t
prepped_inputs = tensor_tree_map(_prep_inputs, inputs)
flat_batch_dim = 1
for d in orig_batch_dims:
flat_batch_dim *= d
no_chunks = flat_batch_dim // chunk_size + (
flat_batch_dim % chunk_size != 0
)
i = 0
out = None
for _ in range(no_chunks):
# Chunk the input
if(not low_mem):
select_chunk = (
lambda t: t[i : i + chunk_size] if t.shape[0] != 1 else t
)
else:
select_chunk = (
partial(
_chunk_slice,
flat_start=i,
flat_end=min(flat_batch_dim, i + chunk_size),
no_batch_dims=len(orig_batch_dims)
)
)
chunks = tensor_tree_map(select_chunk, prepped_inputs)
# Run the layer on the chunk
output_chunk = layer(**chunks)
# Allocate space for the output
if out is None:
allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:])
out = tensor_tree_map(allocate, output_chunk)
# Put the chunk in its pre-allocated space
out_type = type(output_chunk)
if out_type is dict:
def assign(d1, d2):
for k, v in d1.items():
if type(v) is dict:
assign(v, d2[k])
else:
v[i : i + chunk_size] = d2[k]
assign(out, output_chunk)
elif out_type is tuple:
for x1, x2 in zip(out, output_chunk):
x1[i : i + chunk_size] = x2
elif out_type is torch.Tensor:
out[i : i + chunk_size] = output_chunk
else:
raise ValueError("Not supported")
i += chunk_size
reshape = lambda t: t.view(orig_batch_dims + t.shape[1:])
out = tensor_tree_map(reshape, out)
return out
class ChunkSizeTuner:
def __init__(self,
# Heuristically, runtimes for most of the modules in the network
# plateau earlier than this on all GPUs I've run the model on.
max_chunk_size=256,
):
self.max_chunk_size = max_chunk_size
self.cached_chunk_size = None
self.cached_arg_data = None
def _determine_favorable_chunk_size(self, fn, args, min_chunk_size):
logging.info("Tuning chunk size...")
if(min_chunk_size >= self.max_chunk_size):
return min_chunk_size
candidates = [2**l for l in range(int(math.log(self.max_chunk_size, 2)) + 1)]
candidates = [c for c in candidates if c > min_chunk_size]
candidates = [min_chunk_size] + candidates
def test_chunk_size(chunk_size):
try:
with torch.no_grad():
fn(*args, chunk_size=chunk_size)
return True
except RuntimeError:
return False
min_viable_chunk_size_index = 0
i = len(candidates) - 1
while i > min_viable_chunk_size_index:
viable = test_chunk_size(candidates[i])
if(not viable):
i = (min_viable_chunk_size_index + i) // 2
else:
min_viable_chunk_size_index = i
i = (i + len(candidates) - 1) // 2
return candidates[min_viable_chunk_size_index]
def tune_chunk_size(self,
representative_fn: Callable,
args: Tuple[Any],
min_chunk_size: int,
) -> int:
consistent = True
arg_data = [
arg if type(arg) != torch.Tensor else arg.shape for arg in args
]
if(self.cached_arg_data is not None):
# If args have changed shape/value, we need to re-tune
assert(len(self.cached_arg_data) == len(args))
arg_data_iter = zip(self.cached_arg_data, arg_data)
for cached_arg_data, arg_data in arg_data_iter:
assert(type(cached_arg_data) == type(arg_data))
consistent = cached_arg_data == arg_data
else:
# Otherwise, we can reuse the precomputed value
consistent = False
if(not consistent):
self.cached_chunk_size = self._determine_favorable_chunk_size(
representative_fn,
args,
min_chunk_size,
)
self.cached_arg_data = arg_data
return self.cached_chunk_size
@@ -46,6 +46,11 @@ from openfold.utils.tensor_utils import (
from scripts.utils import add_data_args
logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)
def precompute_alignments(tags, seqs, alignment_dir, args):
for tag, seq in zip(tags, seqs):
tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
@@ -54,11 +59,10 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
local_alignment_dir = os.path.join(alignment_dir, tag)
if(args.use_precomputed_alignments is None):
logging.info(f"Generating alignments for {tag}...")
logger.info(f"Generating alignments for {tag}...")
if not os.path.exists(local_alignment_dir):
os.makedirs(local_alignment_dir)
use_small_bfd=(args.bfd_database_path is None)
alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
@@ -68,7 +72,6 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
pdb70_database_path=args.pdb70_database_path,
use_small_bfd=use_small_bfd,
no_cpus=args.cpus,
)
alignment_runner.run(
@@ -80,7 +83,6 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
def run_model(model, batch, tag, args):
logging.info("Executing model...")
with torch.no_grad():
batch = {
k:torch.as_tensor(v, device=args.model_device)
@@ -88,14 +90,14 @@ def run_model(model, batch, tag, args):
}
# Disable templates if there aren't any in the batch
model.config.template.enabled = any([
model.config.template.enabled = model.config.template.enabled and any([
"template_" in k for k in batch
])
logging.info(f"Running inference for {tag}...")
logger.info(f"Running inference for {tag}...")
t = time.perf_counter()
out = model(batch)
logging.info(f"Inference time: {time.perf_counter() - t}")
logger.info(f"Inference time: {time.perf_counter() - t}")
return out
@@ -131,7 +133,7 @@ def prep_output(out, batch, feature_dict, feature_processor, args):
remark = ', '.join([
f"no_recycling={no_recycling}",
f"max_templates={feature_processor.config.predict.max_templates}",
f"config_preset={args.model_name}",
f"config_preset={args.config_preset}",
])
# For multi-chain FASTAs
@@ -160,7 +162,7 @@ def prep_output(out, batch, feature_dict, feature_processor, args):
return unrelaxed_protein
def generate_batch(fasta_file, fasta_dir, alignment_dir, data_processor, feature_processor):
def generate_batch(fasta_file, fasta_dir, alignment_dir, data_processor, feature_processor, prediction_dir):
with open(os.path.join(fasta_dir, fasta_file), "r") as fp:
data = fp.read()
@@ -171,9 +173,21 @@ def generate_batch(fasta_file, fasta_dir, alignment_dir, data_processor, feature
tags, seqs = lines[::2], lines[1::2]
tags = [t.split()[0] for t in tags]
assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tag = '-'.join(tags)
output_name = f'{tag}_{args.config_preset}'
if args.output_postfix is not None:
output_name = f'{output_name}_{args.output_postfix}'
# Save the unrelaxed PDB.
unrelaxed_output_path = os.path.join(
prediction_dir, f'{output_name}_unrelaxed.pdb'
)
if os.path.exists(unrelaxed_output_path):
return
precompute_alignments(tags, seqs, alignment_dir, args)
tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
@@ -215,6 +229,9 @@ def load_models_from_command_line(args, config):
model, path, version=args.model_name
)
model = model.to(args.model_device)
logger.info(
f"Successfully loaded JAX parameters at {args.jax_param_path}..."
)
yield model, None
if args.openfold_checkpoint_path:
for path in args.openfold_checkpoint_path.split(","):
@@ -222,6 +239,7 @@ def load_models_from_command_line(args, config):
model = model.eval()
checkpoint_basename = None
if os.path.isdir(path):
# A DeepSpeed checkpoint
checkpoint_basename = os.path.splitext(
os.path.basename(
os.path.normpath(path)
@@ -237,12 +255,20 @@ def load_models_from_command_line(args, config):
path,
ckpt_path,
)
d = torch.load(ckpt_path)
model.load_state_dict(d["ema"]["params"])
else:
ckpt_path = path
d = torch.load(ckpt_path)
model.load_state_dict(d["ema"]["params"])
if ("ema" in d):
# The public weights have had this done to them already
d = d["ema"]["params"]
model.load_state_dict(d)
model = model.to(args.model_device)
logger.info(
f"Loaded OpenFold parameters at {args.openfold_checkpoint_path}..."
)
yield model, checkpoint_basename
if not args.jax_param_path and not args.openfold_checkpoint_path:
raise ValueError(
@@ -255,7 +281,7 @@ def main(args):
# Create the output directory
os.makedirs(args.output_dir, exist_ok=True)
config = model_config(args.model_name)
config = model_config(args.config_preset)
template_featurizer = templates.TemplateHitFeaturizer(
mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date,
@@ -280,13 +306,20 @@ def main(args):
alignment_dir = os.path.join(output_dir_base, "alignments")
else:
alignment_dir = args.use_precomputed_alignments
logger.info(f"Using precomputed alignments at {alignment_dir}...")
prediction_dir = os.path.join(args.output_dir, "predictions")
os.makedirs(prediction_dir, exist_ok=True)
for fasta_file in os.listdir(args.fasta_dir):
batch, tag, feature_dict = generate_batch(fasta_file, args.fasta_dir, alignment_dir, data_processor, feature_processor)
batch, tag, feature_dict = generate_batch(
fasta_file,
args.fasta_dir,
alignment_dir,
data_processor,
feature_processor,
prediction_dir)
for model, model_version in load_models_from_command_line(args, config):
@@ -315,6 +348,7 @@ def main(args):
with open(unrelaxed_output_path, 'w') as fp:
fp.write(protein.to_pdb(unrelaxed_protein))
logger.info(f"Output written to {unrelaxed_output_path}...")
if not args.skip_relaxation:
amber_relaxer = relax.AmberRelaxation(
use_gpu=(args.model_device != "cpu"),
@@ -322,6 +356,7 @@ def main(args):
)
# Relax the prediction.
logger.info(f"Running relaxation on {unrelaxed_output_path}...")
t = time.perf_counter()
visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default="")
if "cuda" in args.model_device:
@@ -329,7 +364,7 @@ def main(args):
os.environ["CUDA_VISIBLE_DEVICES"] = device_no
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
logging.info(f"Relaxation time: {time.perf_counter() - t}")
logger.info(f"Relaxation time: {time.perf_counter() - t}")
# Save the relaxed PDB.
relaxed_output_path = os.path.join(
@@ -337,6 +372,7 @@ def main(args):
)
with open(relaxed_output_path, 'w') as fp:
fp.write(relaxed_pdb_str)
logger.info(f"Relaxed output written to {relaxed_output_path}...")
if args.save_outputs:
output_dict_path = os.path.join(
@@ -345,6 +381,7 @@ def main(args):
with open(output_dict_path, "wb") as fp:
pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL)
logger.info(f"Model output written to {output_dict_path}...")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -370,7 +407,7 @@ if __name__ == "__main__":
device name is accepted (e.g. "cpu", "cuda:0")"""
)
parser.add_argument(
"--model_name", type=str, default="model_1",
"--config_preset", type=str, default="model_1",
help="""Name of a model config. Choose one of model_{1-5} or
model_{1-5}_ptm, as defined on the AlphaFold GitHub."""
)
@@ -417,7 +454,7 @@ if __name__ == "__main__":
if(args.jax_param_path is None and args.openfold_checkpoint_path is None):
args.jax_param_path = os.path.join(
"openfold", "resources", "params",
"params_" + args.model_name + ".npz"
"params_" + args.config_preset + ".npz"
)
if(args.model_device == "cpu" and torch.cuda.is_available()):
@@ -39,13 +39,15 @@
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
echo "Downloading AlphaFold parameters..."
bash "${SCRIPT_DIR}/download_alphafold_params.sh" "${DOWNLOAD_DIR}"
echo "Downloading PDB70..."
bash "${SCRIPT_DIR}/download_pdb70.sh" "${DOWNLOAD_DIR}"
echo "Downloading PDB mmCIF files..."
bash "${SCRIPT_DIR}/download_pdb_mmcif.sh" "${DOWNLOAD_DIR}"
if [[ -d openfold/resources/params ]]; then
ln -s openfold/resources/params "${DOWNLOAD_DIR}/params"
ln -s openfold/resources/openfold_params "${DOWNLOAD_DIR}/params"
fi
echo "All data downloaded."
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads and unzips OpenFold parameters.
#
# Usage: bash download_openfold_params.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument."
exit 1
fi
FILE_ID="1OpeMrfWEUSD_KqffbPqd5p7WsJjlC3ZE"
FILENAME="openfold_params_06_22.tar.gz"
download_from_gdrive() {
FILE_ID="$1"
OUT_DIR="$2"
MSG=$(wget \
--quiet \
--save-cookies /tmp/cookies_$$.txt \
--keep-session-cookies \
--no-check-certificate \
"https://docs.google.com/uc?export=download&id=${FILE_ID}" \
-O- \
)
CONFIRM=$(echo $MSG | sed -rn "s/.*confirm=([0-9A-Za-z_]+).*/\1\n/p")
FILENAME=$(echo $MSG | sed -e "s/.*<a href=\"\/open?id=${FILE_ID}\">\(.*\)<\/a> (.*/\1/")
FILEPATH="${OUT_DIR}/${FILENAME}"
wget \
--quiet \
--load-cookies /tmp/cookies_$$.txt \
"https://docs.google.com/uc?export=download&confirm=${CONFIRM}&id=${FILE_ID}" \
-O "${FILEPATH}"
rm /tmp/cookies_$$.txt
echo $FILEPATH
}
DOWNLOAD_DIR="$1"
mkdir -p "${DOWNLOAD_DIR}"
DOWNLOAD_PATH=$(download_from_gdrive $FILE_ID "${DOWNLOAD_DIR}")
DOWNLOAD_FILENAME=$(basename "${DOWNLOAD_PATH}")
if [[ $FILENAME != $DOWNLOAD_FILENAME ]]; then
echo "Error: Downloaded filename ${DOWNLOAD_FILENAME} does not match expected filename ${FILENAME}"
rm "${DOWNLOAD_PATH}"
exit 1
fi
tar --extract --verbose --file="${DOWNLOAD_PATH}" \
--directory="${DOWNLOAD_DIR}" --preserve-permissions
rm "${DOWNLOAD_PATH}"
@@ -31,8 +31,11 @@ wget -q -P openfold/resources \
mkdir -p tests/test_data/alphafold/common
ln -rs openfold/resources/stereo_chemical_props.txt tests/test_data/alphafold/common
# Download pretrained openfold weights
scripts/download_alphafold_params.sh openfold/resources
echo "Downloading OpenFold parameters..."
bash scripts/download_openfold_params.sh openfold/resources
echo "Downloading AlphaFold parameters..."
bash scripts/download_alphafold_params.sh openfold/resources
# Decompress test data
gunzip tests/test_data/sample_feats.pickle.gz
#!/bin/bash
# Generates uniclust30 all-against-all alignments on a SLURM cluster.
# Thanks to Milot Mirdita for help & feedback on this script.
set -e
if [[ $# != 3 ]]; then
echo "usage: ./run_uniclust30_search.sh <uniclust30_path> <scratch_dir> <out_dir>"
exit 1
fi
UNICLUST_PATH=$1
SCRATCH_DIR_BN=$2
OUT_DIR=$3
CPUS_PER_TASK=4
MAX_SIZE=10000000000 # 10GB
SCRATCH_DIR="${SCRATCH_DIR_BN}_${SLURM_NODEID}"
mkdir -p ${SCRATCH_DIR}
mkdir -p ${OUT_DIR}
# copy database to local ssd
DB_BN=$(basename $UNICLUST_PATH)
DB_DIR="/dev/shm/uniclust30"
mkdir -p $DB_DIR
cp ${UNICLUST_PATH}*.ff* $DB_DIR
DB="${DB_DIR}/${DB_BN}"
for f in $(ls $OUT_DIR/*.zip)
do
zipinfo -1 $f '*/' | awk -F/ '{print $(NF-1)}' >> ${DB_DIR}/already_searched.txt
done
python3 filter_ffindex.py ${DB}_a3m.ffindex ${DB_DIR}/already_searched.txt ${DB_DIR}/filtered_a3m.ffindex
TARGET="${DB}_a3m_${SLURM_NODEID}.ffindex"
split -n "l/$((SLURM_NODEID + 1))/${SLURM_JOB_NUM_NODES}" "${DB_DIR}/filtered_a3m.ffindex" > $TARGET
# fifo-backed counting semaphore: seed fd 3 with $1 three-digit tokens,
# one token per allowed concurrent job
open_sem() {
mkfifo pipe-$$
exec 3<>pipe-$$
rm pipe-$$
local i=$1
for ((;i>0;i--)); do
printf %s 000 >&3
done
}
# run the given command asynchronously and pop/push tokens
run_with_lock() {
local x
# this read waits until there is something to read
read -u 3 -n 3 x && ((0==x)) || exit $x
(
( "$@"; )
# push the return code of the command to the semaphore
printf '%.3d' $? >&3
)&
}
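# Usage sketch (hypothetical command): limit work to 4 concurrent jobs:
#   open_sem 4
#   for f in inputs/*; do run_with_lock some_cmd "$f"; done
#   wait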
task() {
dd if="${DB}_a3m.ffdata" ibs=1 skip="${OFF}" count="${LEN}" status=none | \
hhblits -i stdin \
-oa3m "${SCRATCH_DIR}/${KEY}/uniclust30.a3m" \
-v 0 \
-o /dev/null \
-cpu $CPUS_PER_TASK \
-d $DB \
-n 3 \
-e 0.001
}
zip_or_not() {
SIZE=$(du -bs $SCRATCH_DIR | awk '{print $1}')
# NOTE: the size check is currently disabled; every batch is zipped
#if [[ "$SIZE" -gt "$MAX_SIZE" ]]
if [[ "2" -gt "1" ]]
then
wait
RANDOM_NAME=$(cat /dev/urandom | tr -cd 'a-f0-9' | head -c 32)
zip -r "${OUT_DIR}/${RANDOM_NAME}.zip" $SCRATCH_DIR
find $SCRATCH_DIR -mindepth 1 -type d -exec rm -rf {} +
fi
}
N=$(($(nproc) / ${CPUS_PER_TASK}))
open_sem $N
while read -r KEY OFF LEN; do
PROT_DIR="${SCRATCH_DIR}/${KEY}"
if [[ -d $PROT_DIR ]]
then
continue
fi
mkdir -p $PROT_DIR
run_with_lock task "${KEY}" "${OFF}" "${LEN}"
zip_or_not
done < $TARGET
wait
zip_or_not
wait
@@ -55,7 +55,7 @@ extra_cuda_flags += cc_flag
setup(
name='openfold',
version='0.1.0',
version='1.0.0',
description='A PyTorch reimplementation of DeepMind\'s AlphaFold 2',
author='Gustaf Ahdritz & DeepMind',
author_email='gahdritz@gmail.com',
@@ -92,7 +92,7 @@ class TestOuterProductMean(unittest.TestCase):
# Even when correct, OPM has large, precision-related errors. It gets
# a special pass from consts.eps.
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < 5e-4))
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 5e-4)
if __name__ == "__main__":