Commit f707a9ea authored by Gustaf Ahdritz

Add flat chunk slicing

parent ebcbaa60
......@@ -16,7 +16,7 @@
from functools import partial
import torch
import torch.nn as nn
from typing import Tuple, List, Callable, Any, Dict
from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
......@@ -124,11 +124,177 @@ def _fetch_dims(tree):
return shapes
def _flat_idx_to_idx(
flat_idx: int,
    dims: Tuple[int, ...],
) -> Tuple[int, ...]:
idx = []
for d in reversed(dims):
idx.append(flat_idx % d)
flat_idx = flat_idx // d
return tuple(reversed(idx))
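# Illustrative sketch (not part of the commit; the helper name below is
# hypothetical): _flat_idx_to_idx recovers a per-dimension index from a
# flat index, much like unraveling an index over the batch dims.
def _example_flat_idx_to_idx():
    # For a (3, 4) batch shape, flat index 7 sits at (1, 3): 7 == 1 * 4 + 3
    assert _flat_idx_to_idx(7, (3, 4)) == (1, 3)
    # The last leaf of the grid, flat index 11, maps to (2, 3)
    assert _flat_idx_to_idx(11, (3, 4)) == (2, 3)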
def _get_minimal_slice_set(
start: Sequence[int],
end: Sequence[int],
    dims: Sequence[int],
start_edges: Optional[Sequence[bool]] = None,
end_edges: Optional[Sequence[bool]] = None,
) -> Sequence[Tuple[slice, ...]]:
"""
Produces an ordered sequence of tensor slices that, when used in
sequence on a tensor with shape dims, yields tensors that contain every
leaf in the contiguous range [start, end]. Care is taken to yield a
short sequence of slices, and perhaps even the shortest possible (I'm
pretty sure it's the latter).
end is INCLUSIVE.
"""
# start_edges and end_edges both indicate whether, starting from any given
# dimension, the start/end index is at the top/bottom edge of the
# corresponding tensor, modeled as a tree
    def reduce_edge_list(l):
        # An index only counts as an edge if every trailing dimension is
        # also at its edge, so propagate edge status from right to left
        tally = 1
        for i in range(len(l)):
            reversed_idx = -1 * (i + 1)
            l[reversed_idx] *= tally
            tally = l[reversed_idx]
if(start_edges is None):
start_edges = [s == 0 for s in start]
reduce_edge_list(start_edges)
if(end_edges is None):
end_edges = [e == (d - 1) for e,d in zip(end, dims)]
reduce_edge_list(end_edges)
# Base cases. Either start/end are empty and we're done, or the final,
# one-dimensional tensor can be simply sliced
if(len(start) == 0):
return [tuple()]
elif(len(start) == 1):
return [(slice(start[0], end[0] + 1),)]
slices = []
path = []
# Dimensions common to start and end can be selected directly
for s,e in zip(start, end):
if(s == e):
path.append(slice(s, s + 1))
else:
break
path = tuple(path)
divergence_idx = len(path)
# start == end, and we're done
if(divergence_idx == len(dims)):
return [tuple(path)]
def upper():
sdi = start[divergence_idx]
return [
path + (slice(sdi, sdi + 1),) + s for s in
_get_minimal_slice_set(
start[divergence_idx + 1:],
[d - 1 for d in dims[divergence_idx + 1:]],
dims[divergence_idx + 1:],
start_edges=start_edges[divergence_idx + 1:],
end_edges=[1 for _ in end_edges[divergence_idx + 1:]]
)
]
def lower():
edi = end[divergence_idx]
return [
path + (slice(edi, edi + 1),) + s for s in
_get_minimal_slice_set(
[0 for _ in start[divergence_idx + 1:]],
end[divergence_idx + 1:],
dims[divergence_idx + 1:],
start_edges=[1 for _ in start_edges[divergence_idx + 1:]],
end_edges=end_edges[divergence_idx + 1:],
)
]
# If both start and end are at the edges of the subtree rooted at
# divergence_idx, we can just select the whole subtree at once
if(start_edges[divergence_idx] and end_edges[divergence_idx]):
slices.append(
path + (slice(start[divergence_idx], end[divergence_idx] + 1),)
)
# If just start is at the edge, we can grab almost all of the subtree,
# treating only the ragged bottom edge as an edge case
elif(start_edges[divergence_idx]):
slices.append(
path + (slice(start[divergence_idx], end[divergence_idx]),)
)
slices.extend(lower())
# Analogous to the previous case, but the top is ragged this time
elif(end_edges[divergence_idx]):
slices.extend(upper())
slices.append(
path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),)
)
# If both sides of the range are ragged, we need to handle both sides
# separately. If there's contiguous meat in between them, we can index it
# in one big chunk
else:
slices.extend(upper())
middle_ground = end[divergence_idx] - start[divergence_idx]
if(middle_ground > 1):
slices.append(
path + (slice(start[divergence_idx] + 1, end[divergence_idx]),)
)
slices.extend(lower())
return [tuple(s) for s in slices]
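# Illustrative sketch (not part of the commit; the helper name below is
# hypothetical): on a (3, 4) batch grid, the inclusive flat range [2, 9]
# is covered by three slices: the tail of row 0, all of row 1, and the
# head of row 2.
def _example_minimal_slice_set():
    slices = _get_minimal_slice_set(
        start=[0, 2],  # flat index 2
        end=[2, 1],    # flat index 9 (inclusive)
        dims=(3, 4),
    )
    assert slices == [
        (slice(0, 1), slice(2, 4)),
        (slice(1, 2),),
        (slice(2, 3), slice(0, 2)),
    ]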
def _chunk_slice(
t: torch.Tensor,
flat_start: int,
flat_end: int,
no_batch_dims: int,
) -> torch.Tensor:
"""
Equivalent to
t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]
but without the need for the reshape call, which can be
memory-intensive in certain situations. The only reshape operations
in this function are performed on sub-tensors that scale with
(flat_end - flat_start), the chunk size.
"""
batch_dims = t.shape[:no_batch_dims]
start_idx = list(_flat_idx_to_idx(flat_start, batch_dims))
# _get_minimal_slice_set is inclusive
end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims))
# Get an ordered list of slices to perform
slices = _get_minimal_slice_set(
start_idx,
end_idx,
batch_dims,
)
    # Apply the slices, then flatten each piece to the chunk layout
sliced_tensors = [t[s] for s in slices]
return torch.cat(
[s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors]
)
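# Illustrative sketch (not part of the commit; the helper name below is
# hypothetical): _chunk_slice should agree with the reshape-then-slice
# expression it replaces.
def _example_chunk_slice():
    t = torch.rand(3, 4, 5)  # two batch dims of shape (3, 4)
    chunk = _chunk_slice(t, flat_start=2, flat_end=9, no_batch_dims=2)
    reference = t.reshape(-1, 5)[2:9]  # the memory-hungrier equivalent
    assert torch.equal(chunk, reference)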
def chunk_layer(
layer: Callable,
inputs: Dict[str, Any],
chunk_size: int,
no_batch_dims: int,
low_mem: bool = False,
) -> Any:
"""
Implements the "chunking" procedure described in section 1.11.8.
......@@ -151,6 +317,10 @@ def chunk_layer(
no_batch_dims:
How many of the initial dimensions of each input tensor can
be considered batch dimensions.
        low_mem:
            Avoids flattening potentially large input tensors. Unnecessary
            in most cases, and slightly slower than the default setting.
Returns:
The reassembled output of the layer on the inputs.
"""
......@@ -162,12 +332,15 @@ def chunk_layer(
def _prep_inputs(t):
# TODO: make this more memory efficient. This sucks
if(not low_mem):
if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
t = t.expand(*orig_batch_dims, *t.shape[no_batch_dims:])
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
t = t.reshape(-1, *t.shape[no_batch_dims:])
else:
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
return t
flattened_inputs = tensor_tree_map(_prep_inputs, inputs)
prepped_inputs = tensor_tree_map(_prep_inputs, inputs)
flat_batch_dim = 1
for d in orig_batch_dims:
......@@ -179,10 +352,24 @@ def chunk_layer(
i = 0
out = None
for _ in range(no_chunks):
# Chunk the input
select_chunk = lambda t: t[i : i + chunk_size] if t.shape[0] != 1 else t
chunks = tensor_tree_map(select_chunk, flattened_inputs)
if(not low_mem):
select_chunk = (
lambda t: t[i : i + chunk_size] if t.shape[0] != 1 else t
)
else:
select_chunk = (
partial(
_chunk_slice,
flat_start=i,
flat_end=min(flat_batch_dim, i + chunk_size),
no_batch_dims=len(orig_batch_dims)
)
)
chunks = tensor_tree_map(select_chunk, prepped_inputs)
# Run the layer on the chunk
output_chunk = layer(**chunks)
......@@ -214,7 +401,7 @@ def chunk_layer(
i += chunk_size
reshape = lambda t: t.reshape(orig_batch_dims + t.shape[1:])
reshape = lambda t: t.view(orig_batch_dims + t.shape[1:])
out = tensor_tree_map(reshape, out)
return out
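# Illustrative sketch (not part of the commit; the layer and names below are
# hypothetical): with low_mem=True, chunk_layer skips the up-front flattening
# of the inputs and instead carves out each chunk with _chunk_slice.
def _example_chunk_layer_low_mem():
    linear = nn.Linear(8, 8)
    inputs = {"x": torch.rand(2, 6, 100, 8)}  # batch dims: (2, 6, 100)
    out = chunk_layer(
        lambda x: linear(x),
        inputs,
        chunk_size=64,
        no_batch_dims=3,
        low_mem=True,
    )
    assert out.shape == (2, 6, 100, 8)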
......@@ -17,7 +17,7 @@ import torch
import unittest
from openfold.utils.affine_utils import T, quat_to_rot
from openfold.utils.tensor_utils import chunk_layer
from openfold.utils.tensor_utils import chunk_layer, _chunk_slice
X_90_ROT = torch.tensor(
......@@ -37,7 +37,7 @@ X_NEG_90_ROT = torch.tensor(
)
class TestAffineT(unittest.TestCase):
class TestUtils(unittest.TestCase):
def test_T_from_3_points_shape(self):
batch_size = 2
n_res = 5
......@@ -165,3 +165,18 @@ class TestAffineT(unittest.TestCase):
self.assertTrue(
torch.all(chunked["inner"]["out"] == unchunked["inner"]["out"])
)
def test_chunk_slice_dict(self):
x = torch.rand(3, 4, 3, 5)
x_flat = x.view(-1, 5)
prod = 1
for d in x.shape[:-1]:
prod = prod * d
for i in range(prod):
for j in range(i + 1, prod + 1):
chunked = _chunk_slice(x, i, j, len(x.shape[:-1]))
chunked_flattened = x_flat[i:j]
self.assertTrue(torch.all(chunked == chunked_flattened))