"tools/git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "b71c6a11fe97878e86dc5a290c3747a6921f79b3"
Commit f707a9ea authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add flat chunk slicing

parent ebcbaa60
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
from functools import partial from functools import partial
import torch import torch
import torch.nn as nn import torch.nn as nn
from typing import Tuple, List, Callable, Any, Dict from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional
def permute_final_dims(tensor: torch.Tensor, inds: List[int]): def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
...@@ -124,11 +124,177 @@ def _fetch_dims(tree): ...@@ -124,11 +124,177 @@ def _fetch_dims(tree):
return shapes return shapes
def _flat_idx_to_idx(
flat_idx: int,
dims: Tuple[int],
) -> Tuple[int]:
idx = []
for d in reversed(dims):
idx.append(flat_idx % d)
flat_idx = flat_idx // d
return tuple(reversed(idx))
def _get_minimal_slice_set(
start: Sequence[int],
end: Sequence[int],
dims: int,
start_edges: Optional[Sequence[bool]] = None,
end_edges: Optional[Sequence[bool]] = None,
) -> Sequence[Tuple[int]]:
"""
Produces an ordered sequence of tensor slices that, when used in
sequence on a tensor with shape dims, yields tensors that contain every
leaf in the contiguous range [start, end]. Care is taken to yield a
short sequence of slices, and perhaps even the shortest possible (I'm
pretty sure it's the latter).
end is INCLUSIVE.
"""
# start_edges and end_edges both indicate whether, starting from any given
# dimension, the start/end index is at the top/bottom edge of the
# corresponding tensor, modeled as a tree
def reduce_edge_list(l):
tally = 1
for i in range(len(l)):
reversed_idx = -1 * (i + 1)
l[reversed_idx] *= tally
tally = l[reversed_idx]
if(start_edges is None):
start_edges = [s == 0 for s in start]
reduce_edge_list(start_edges)
if(end_edges is None):
end_edges = [e == (d - 1) for e,d in zip(end, dims)]
reduce_edge_list(end_edges)
# Base cases. Either start/end are empty and we're done, or the final,
# one-dimensional tensor can be simply sliced
if(len(start) == 0):
return [tuple()]
elif(len(start) == 1):
return [(slice(start[0], end[0] + 1),)]
slices = []
path = []
# Dimensions common to start and end can be selected directly
for s,e in zip(start, end):
if(s == e):
path.append(slice(s, s + 1))
else:
break
path = tuple(path)
divergence_idx = len(path)
# start == end, and we're done
if(divergence_idx == len(dims)):
return [tuple(path)]
def upper():
sdi = start[divergence_idx]
return [
path + (slice(sdi, sdi + 1),) + s for s in
_get_minimal_slice_set(
start[divergence_idx + 1:],
[d - 1 for d in dims[divergence_idx + 1:]],
dims[divergence_idx + 1:],
start_edges=start_edges[divergence_idx + 1:],
end_edges=[1 for _ in end_edges[divergence_idx + 1:]]
)
]
def lower():
edi = end[divergence_idx]
return [
path + (slice(edi, edi + 1),) + s for s in
_get_minimal_slice_set(
[0 for _ in start[divergence_idx + 1:]],
end[divergence_idx + 1:],
dims[divergence_idx + 1:],
start_edges=[1 for _ in start_edges[divergence_idx + 1:]],
end_edges=end_edges[divergence_idx + 1:],
)
]
# If both start and end are at the edges of the subtree rooted at
# divergence_idx, we can just select the whole subtree at once
if(start_edges[divergence_idx] and end_edges[divergence_idx]):
slices.append(
path + (slice(start[divergence_idx], end[divergence_idx] + 1),)
)
# If just start is at the edge, we can grab almost all of the subtree,
# treating only the ragged bottom edge as an edge case
elif(start_edges[divergence_idx]):
slices.append(
path + (slice(start[divergence_idx], end[divergence_idx]),)
)
slices.extend(lower())
# Analogous to the previous case, but the top is ragged this time
elif(end_edges[divergence_idx]):
slices.extend(upper())
slices.append(
path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),)
)
# If both sides of the range are ragged, we need to handle both sides
# separately. If there's contiguous meat in between them, we can index it
# in one big chunk
else:
slices.extend(upper())
middle_ground = end[divergence_idx] - start[divergence_idx]
if(middle_ground > 1):
slices.append(
path + (slice(start[divergence_idx] + 1, end[divergence_idx]),)
)
slices.extend(lower())
return [tuple(s) for s in slices]
def _chunk_slice(
t: torch.Tensor,
flat_start: int,
flat_end: int,
no_batch_dims: int,
):
"""
Equivalent to
t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]
but without the need for the reshape call, which can be
memory-intensive in certain situations. The only reshape operations
in this function are performed on sub-tensors that scale with
(flat_end - flat_start), the chunk size.
"""
batch_dims = t.shape[:no_batch_dims]
start_idx = list(_flat_idx_to_idx(flat_start, batch_dims))
# _get_minimal_slice_set is inclusive
end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims))
# Get an ordered list of slices to perform
slices = _get_minimal_slice_set(
start_idx,
end_idx,
batch_dims,
)
#
sliced_tensors = [t[s] for s in slices]
return torch.cat(
[s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors]
)
def chunk_layer( def chunk_layer(
layer: Callable, layer: Callable,
inputs: Dict[str, Any], inputs: Dict[str, Any],
chunk_size: int, chunk_size: int,
no_batch_dims: int, no_batch_dims: int,
low_mem: bool = False,
) -> Any: ) -> Any:
""" """
Implements the "chunking" procedure described in section 1.11.8. Implements the "chunking" procedure described in section 1.11.8.
...@@ -151,6 +317,10 @@ def chunk_layer( ...@@ -151,6 +317,10 @@ def chunk_layer(
no_batch_dims: no_batch_dims:
How many of the initial dimensions of each input tensor can How many of the initial dimensions of each input tensor can
be considered batch dimensions. be considered batch dimensions.
low_mem:
Avoids flattening potentially large input tensors. Unnecessary
in most cases, and is ever so slightly slower than the default
setting.
Returns: Returns:
The reassembled output of the layer on the inputs. The reassembled output of the layer on the inputs.
""" """
...@@ -162,12 +332,15 @@ def chunk_layer( ...@@ -162,12 +332,15 @@ def chunk_layer(
def _prep_inputs(t): def _prep_inputs(t):
# TODO: make this more memory efficient. This sucks # TODO: make this more memory efficient. This sucks
if(not low_mem):
if not sum(t.shape[:no_batch_dims]) == no_batch_dims: if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
t = t.expand(*orig_batch_dims, *t.shape[no_batch_dims:]) t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
t = t.reshape(-1, *t.shape[no_batch_dims:]) t = t.reshape(-1, *t.shape[no_batch_dims:])
else:
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
return t return t
flattened_inputs = tensor_tree_map(_prep_inputs, inputs) prepped_inputs = tensor_tree_map(_prep_inputs, inputs)
flat_batch_dim = 1 flat_batch_dim = 1
for d in orig_batch_dims: for d in orig_batch_dims:
...@@ -179,10 +352,24 @@ def chunk_layer( ...@@ -179,10 +352,24 @@ def chunk_layer(
i = 0 i = 0
out = None out = None
for _ in range(no_chunks): for _ in range(no_chunks):
# Chunk the input # Chunk the input
select_chunk = lambda t: t[i : i + chunk_size] if t.shape[0] != 1 else t if(not low_mem):
chunks = tensor_tree_map(select_chunk, flattened_inputs) select_chunk = (
lambda t: t[i : i + chunk_size] if t.shape[0] != 1 else t
)
else:
select_chunk = (
partial(
_chunk_slice,
flat_start=i,
flat_end=min(flat_batch_dim, i + chunk_size),
no_batch_dims=len(orig_batch_dims)
)
)
chunks = tensor_tree_map(select_chunk, prepped_inputs)
# Run the layer on the chunk # Run the layer on the chunk
output_chunk = layer(**chunks) output_chunk = layer(**chunks)
...@@ -214,7 +401,7 @@ def chunk_layer( ...@@ -214,7 +401,7 @@ def chunk_layer(
i += chunk_size i += chunk_size
reshape = lambda t: t.reshape(orig_batch_dims + t.shape[1:]) reshape = lambda t: t.view(orig_batch_dims + t.shape[1:])
out = tensor_tree_map(reshape, out) out = tensor_tree_map(reshape, out)
return out return out
...@@ -17,7 +17,7 @@ import torch ...@@ -17,7 +17,7 @@ import torch
import unittest import unittest
from openfold.utils.affine_utils import T, quat_to_rot from openfold.utils.affine_utils import T, quat_to_rot
from openfold.utils.tensor_utils import chunk_layer from openfold.utils.tensor_utils import chunk_layer, _chunk_slice
X_90_ROT = torch.tensor( X_90_ROT = torch.tensor(
...@@ -37,7 +37,7 @@ X_NEG_90_ROT = torch.tensor( ...@@ -37,7 +37,7 @@ X_NEG_90_ROT = torch.tensor(
) )
class TestAffineT(unittest.TestCase): class TestUtils(unittest.TestCase):
def test_T_from_3_points_shape(self): def test_T_from_3_points_shape(self):
batch_size = 2 batch_size = 2
n_res = 5 n_res = 5
...@@ -165,3 +165,18 @@ class TestAffineT(unittest.TestCase): ...@@ -165,3 +165,18 @@ class TestAffineT(unittest.TestCase):
self.assertTrue( self.assertTrue(
torch.all(chunked["inner"]["out"] == unchunked["inner"]["out"]) torch.all(chunked["inner"]["out"] == unchunked["inner"]["out"])
) )
def test_chunk_slice_dict(self):
x = torch.rand(3, 4, 3, 5)
x_flat = x.view(-1, 5)
prod = 1
for d in x.shape[:-1]:
prod = prod * d
for i in range(prod):
for j in range(i + 1, prod + 1):
chunked = _chunk_slice(x, i, j, len(x.shape[:-1]))
chunked_flattened = x_flat[i:j]
self.assertTrue(torch.all(chunked == chunked_flattened))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment