OpenDAS / OpenFold

Commit f707a9ea, authored Dec 02, 2021 by Gustaf Ahdritz
Parent: ebcbaa60

Add flat chunk slicing

Showing 2 changed files with 212 additions and 10 deletions
openfold/utils/tensor_utils.py   +195  -8
tests/test_utils.py              +17   -2
openfold/utils/tensor_utils.py

@@ -16,7 +16,7 @@
 from functools import partial
 import torch
 import torch.nn as nn
-from typing import Tuple, List, Callable, Any, Dict
+from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional


 def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
@@ -124,11 +124,177 @@ def _fetch_dims(tree):
     return shapes


+def _flat_idx_to_idx(
+    flat_idx: int,
+    dims: Tuple[int],
+) -> Tuple[int]:
+    idx = []
+    for d in reversed(dims):
+        idx.append(flat_idx % d)
+        flat_idx = flat_idx // d
+
+    return tuple(reversed(idx))
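For illustration only (not part of this commit): _flat_idx_to_idx inverts the row-major flattening that reshape((-1,) + ...) performs over the batch dimensions. The shape below is an arbitrary example.

    # Hypothetical check: flat position 7 in a (3, 4) batch grid sits at
    # row 1, column 3, because 1 * 4 + 3 == 7.
    assert _flat_idx_to_idx(7, (3, 4)) == (1, 3)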
+
+
+def _get_minimal_slice_set(
+    start: Sequence[int],
+    end: Sequence[int],
+    dims: int,
+    start_edges: Optional[Sequence[bool]] = None,
+    end_edges: Optional[Sequence[bool]] = None,
+) -> Sequence[Tuple[int]]:
+    """
+        Produces an ordered sequence of tensor slices that, when used in
+        sequence on a tensor with shape dims, yields tensors that contain every
+        leaf in the contiguous range [start, end]. Care is taken to yield a
+        short sequence of slices, and perhaps even the shortest possible (I'm
+        pretty sure it's the latter).
+
+        end is INCLUSIVE.
+    """
+    # start_edges and end_edges both indicate whether, starting from any given
+    # dimension, the start/end index is at the top/bottom edge of the
+    # corresponding tensor, modeled as a tree
+    def reduce_edge_list(l):
+        tally = 1
+        for i in range(len(l)):
+            reversed_idx = -1 * (i + 1)
+            l[reversed_idx] *= tally
+            tally = l[reversed_idx]
+
+    if(start_edges is None):
+        start_edges = [s == 0 for s in start]
+        reduce_edge_list(start_edges)
+    if(end_edges is None):
+        end_edges = [e == (d - 1) for e, d in zip(end, dims)]
+        reduce_edge_list(end_edges)
+
+    # Base cases. Either start/end are empty and we're done, or the final,
+    # one-dimensional tensor can be simply sliced
+    if(len(start) == 0):
+        return [tuple()]
+    elif(len(start) == 1):
+        return [(slice(start[0], end[0] + 1),)]
+
+    slices = []
+    path = []
+
+    # Dimensions common to start and end can be selected directly
+    for s, e in zip(start, end):
+        if(s == e):
+            path.append(slice(s, s + 1))
+        else:
+            break
+
+    path = tuple(path)
+    divergence_idx = len(path)
+
+    # start == end, and we're done
+    if(divergence_idx == len(dims)):
+        return [tuple(path)]
+
+    def upper():
+        sdi = start[divergence_idx]
+        return [
+            path + (slice(sdi, sdi + 1),) + s for s in
+            _get_minimal_slice_set(
+                start[divergence_idx + 1:],
+                [d - 1 for d in dims[divergence_idx + 1:]],
+                dims[divergence_idx + 1:],
+                start_edges=start_edges[divergence_idx + 1:],
+                end_edges=[1 for _ in end_edges[divergence_idx + 1:]],
+            )
+        ]
+
+    def lower():
+        edi = end[divergence_idx]
+        return [
+            path + (slice(edi, edi + 1),) + s for s in
+            _get_minimal_slice_set(
+                [0 for _ in start[divergence_idx + 1:]],
+                end[divergence_idx + 1:],
+                dims[divergence_idx + 1:],
+                start_edges=[1 for _ in start_edges[divergence_idx + 1:]],
+                end_edges=end_edges[divergence_idx + 1:],
+            )
+        ]
+
+    # If both start and end are at the edges of the subtree rooted at
+    # divergence_idx, we can just select the whole subtree at once
+    if(start_edges[divergence_idx] and end_edges[divergence_idx]):
+        slices.append(
+            path + (slice(start[divergence_idx], end[divergence_idx] + 1),)
+        )
+    # If just start is at the edge, we can grab almost all of the subtree,
+    # treating only the ragged bottom edge as an edge case
+    elif(start_edges[divergence_idx]):
+        slices.append(
+            path + (slice(start[divergence_idx], end[divergence_idx]),)
+        )
+        slices.extend(lower())
+    # Analogous to the previous case, but the top is ragged this time
+    elif(end_edges[divergence_idx]):
+        slices.extend(upper())
+        slices.append(
+            path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),)
+        )
+    # If both sides of the range are ragged, we need to handle both sides
+    # separately. If there's contiguous meat in between them, we can index it
+    # in one big chunk
+    else:
+        slices.extend(upper())
+        middle_ground = end[divergence_idx] - start[divergence_idx]
+        if(middle_ground > 1):
+            slices.append(
+                path + (slice(start[divergence_idx] + 1, end[divergence_idx]),)
+            )
+        slices.extend(lower())
+
+    return [tuple(s) for s in slices]
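For illustration only (not part of this commit): for a (3, 4) batch grid, the inclusive flat range from position 2 to position 9 comes back as three slices, namely the ragged tail of row 0, the whole of row 1, and the ragged head of row 2.

    # Hypothetical example: cover flat positions 2..9 of a (3, 4) batch grid.
    slices = _get_minimal_slice_set(
        start=(0, 2),   # multi-index of flat position 2
        end=(2, 1),     # multi-index of flat position 9 (inclusive)
        dims=(3, 4),
    )
    # slices == [
    #     (slice(0, 1), slice(2, 4)),   # tail of row 0
    #     (slice(1, 2),),               # all of row 1
    #     (slice(2, 3), slice(0, 2)),   # head of row 2
    # ]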
+
+
+def _chunk_slice(
+    t: torch.Tensor,
+    flat_start: int,
+    flat_end: int,
+    no_batch_dims: int,
+):
+    """
+        Equivalent to
+
+            t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]
+
+        but without the need for the reshape call, which can be
+        memory-intensive in certain situations. The only reshape operations
+        in this function are performed on sub-tensors that scale with
+        (flat_end - flat_start), the chunk size.
+    """
+
+    batch_dims = t.shape[:no_batch_dims]
+    start_idx = list(_flat_idx_to_idx(flat_start, batch_dims))
+    # _get_minimal_slice_set is inclusive
+    end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims))
+
+    # Get an ordered list of slices to perform
+    slices = _get_minimal_slice_set(
+        start_idx,
+        end_idx,
+        batch_dims,
+    )
+
+    sliced_tensors = [t[s] for s in slices]
+
+    return torch.cat(
+        [s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors]
+    )
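For illustration only (not part of this commit): the result should match the flatten-then-slice expression quoted in the docstring, without ever reshaping the full tensor.

    # Hypothetical check on a small tensor with two batch dimensions.
    t = torch.rand(3, 4, 7)
    reference = t.reshape((-1,) + t.shape[2:])[5:11]
    chunked = _chunk_slice(t, flat_start=5, flat_end=11, no_batch_dims=2)
    assert torch.all(reference == chunked)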
+
+
 def chunk_layer(
     layer: Callable,
     inputs: Dict[str, Any],
     chunk_size: int,
     no_batch_dims: int,
+    low_mem: bool = False,
 ) -> Any:
     """
         Implements the "chunking" procedure described in section 1.11.8.
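For illustration only (not part of this commit): the new low_mem flag simply extends the existing call. The toy layer and shapes below are made-up placeholders, sketched under the assumption that the layer's keyword arguments match the keys of the inputs dict, as layer(**chunks) in the body implies.

    # Hypothetical usage: chunk a toy elementwise "layer" over two batch dims,
    # asking chunk_layer to avoid flattening the prepared inputs.
    out = chunk_layer(
        lambda x, y: {"out": x + y},
        {"x": torch.rand(4, 128, 8), "y": torch.rand(4, 128, 8)},
        chunk_size=16,
        no_batch_dims=2,
        low_mem=True,
    )
    # out["out"].shape == (4, 128, 8)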
@@ -151,6 +317,10 @@ def chunk_layer(
         no_batch_dims:
             How many of the initial dimensions of each input tensor can
             be considered batch dimensions.
+        low_mem:
+            Avoids flattening potentially large input tensors. Unnecessary
+            in most cases, and is ever so slightly slower than the default
+            setting.
     Returns:
         The reassembled output of the layer on the inputs.
     """
@@ -162,12 +332,15 @@ def chunk_layer(
     def _prep_inputs(t):
         # TODO: make this more memory efficient. This sucks
-        if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
-            t = t.expand(*orig_batch_dims, *t.shape[no_batch_dims:])
-        t = t.reshape(-1, *t.shape[no_batch_dims:])
+        if(not low_mem):
+            if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
+                t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
+            t = t.reshape(-1, *t.shape[no_batch_dims:])
+        else:
+            t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
         return t

-    flattened_inputs = tensor_tree_map(_prep_inputs, inputs)
+    prepped_inputs = tensor_tree_map(_prep_inputs, inputs)

     flat_batch_dim = 1
     for d in orig_batch_dims:
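For illustration only (not part of this commit): the low_mem branch stops at expand(), which is a zero-copy view, whereas the default path's reshape() of that expanded view has to materialize the broadcast tensor.

    # Hypothetical demonstration of the copy being avoided.
    t = torch.rand(1, 128, 8)
    expanded = t.expand(64, 128, 8)        # view; broadcast dim has stride 0
    assert expanded.stride(0) == 0
    flattened = expanded.reshape(-1, 8)    # copies all 64 * 128 * 8 elements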
@@ -179,10 +352,24 @@ def chunk_layer(
     i = 0
     out = None
     for _ in range(no_chunks):
         # Chunk the input
-        select_chunk = lambda t: t[i: i + chunk_size] if t.shape[0] != 1 else t
-        chunks = tensor_tree_map(select_chunk, flattened_inputs)
+        if(not low_mem):
+            select_chunk = (
+                lambda t: t[i: i + chunk_size] if t.shape[0] != 1 else t
+            )
+        else:
+            select_chunk = (
+                partial(
+                    _chunk_slice,
+                    flat_start=i,
+                    flat_end=min(flat_batch_dim, i + chunk_size),
+                    no_batch_dims=len(orig_batch_dims),
+                )
+            )
+
+        chunks = tensor_tree_map(select_chunk, prepped_inputs)

         # Run the layer on the chunk
         output_chunk = layer(**chunks)
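For illustration only (not part of this commit): like the lambda it replaces, the bound partial above is a one-argument callable, so tensor_tree_map can apply it to every leaf tensor; it just gathers the chunk with _chunk_slice instead of indexing a pre-flattened input.

    # Hypothetical check that the bound partial matches plain flat slicing.
    select_chunk = partial(_chunk_slice, flat_start=0, flat_end=4, no_batch_dims=2)
    t = torch.rand(3, 4, 7)
    assert torch.all(select_chunk(t) == t.reshape(-1, 7)[0:4])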
@@ -214,7 +401,7 @@ def chunk_layer(
         i += chunk_size

-    reshape = lambda t: t.reshape(orig_batch_dims + t.shape[1:])
+    reshape = lambda t: t.view(orig_batch_dims + t.shape[1:])
     out = tensor_tree_map(reshape, out)

     return out
tests/test_utils.py

@@ -17,7 +17,7 @@ import torch
 import unittest

 from openfold.utils.affine_utils import T, quat_to_rot
-from openfold.utils.tensor_utils import chunk_layer
+from openfold.utils.tensor_utils import chunk_layer, _chunk_slice

 X_90_ROT = torch.tensor(
@@ -37,7 +37,7 @@ X_NEG_90_ROT = torch.tensor(
 )


-class TestAffineT(unittest.TestCase):
+class TestUtils(unittest.TestCase):
     def test_T_from_3_points_shape(self):
         batch_size = 2
         n_res = 5
@@ -165,3 +165,18 @@ class TestAffineT(unittest.TestCase):
         self.assertTrue(
             torch.all(chunked["inner"]["out"] == unchunked["inner"]["out"])
         )
+
+    def test_chunk_slice_dict(self):
+        x = torch.rand(3, 4, 3, 5)
+        x_flat = x.view(-1, 5)
+        prod = 1
+        for d in x.shape[:-1]:
+            prod = prod * d
+
+        for i in range(prod):
+            for j in range(i + 1, prod + 1):
+                chunked = _chunk_slice(x, i, j, len(x.shape[:-1]))
+                chunked_flattened = x_flat[i:j]
+                self.assertTrue(torch.all(chunked == chunked_flattened))