OpenDAS / OpenFold · Commits

Commit 96709588, authored Dec 17, 2021 by Gustaf Ahdritz

    Add low-memory attention (still needs to be incorporated)

Parent: c4d9f57f
Showing 5 changed files with 253 additions and 10 deletions.
openfold/model/primitives.py    +176  -2
openfold/utils/tensor_utils.py    +6  -5
scripts/run_unit_tests.sh         +1  -1
tests/compare_utils.py            +0  -2
tests/test_primitives.py         +70  -0
openfold/model/primitives.py
@@ -14,7 +14,7 @@
# limitations under the License.

import math
-from typing import Optional, Callable, List
+from typing import Optional, Callable, List, Tuple, Sequence
import numpy as np
import torch

@@ -24,6 +24,7 @@ from scipy.stats import truncnorm
from openfold.utils.tensor_utils import (
    permute_final_dims,
    flatten_final_dims,
+    _chunk_slice,
)
@@ -217,7 +218,7 @@ class Attention(nn.Module):
            self.c_hidden * self.no_heads, self.c_q, init="final"
        )

-        if self.gating is not None:
+        if self.gating:
            self.linear_g = Linear(
                self.c_q, self.c_hidden * self.no_heads, init="gating"
            )
@@ -370,3 +371,176 @@ class GlobalAttention(nn.Module):
        m = self.linear_o(o)

        return m


@torch.jit.script
def _lma(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    biases: List[torch.Tensor],
    q_chunk_size: int,
    kv_chunk_size: int
):
    no_q, no_kv = q.shape[-3], k.shape[-3]

    # [*, Q, H, C_hidden]
    o = q.new_zeros(q.shape)
    for q_s in range(0, no_q, q_chunk_size):
        q_chunk = q[..., q_s:q_s + q_chunk_size, :, :]
        big_bias_chunks = [
            b[..., q_s:q_s + q_chunk_size, :] for b in biases
        ]

        maxes = []
        weights = []
        values = []
        for kv_s in range(0, no_kv, kv_chunk_size):
            k_chunk = k[..., kv_s:kv_s + kv_chunk_size, :, :]
            v_chunk = v[..., kv_s:kv_s + kv_chunk_size, :, :]
            small_bias_chunks = [
                b[..., kv_s:kv_s + kv_chunk_size] for b in big_bias_chunks
            ]

            a = torch.einsum("...qhd,...khd->...hqk", q_chunk, k_chunk)

            for b in small_bias_chunks:
                a += b

            a = a.transpose(-2, -3)

            max_a = torch.max(a, dim=-1, keepdim=True)[0].detach()
            exp_a = torch.exp(a - max_a)
            exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a)

            maxes.append(max_a.squeeze(-1))
            weights.append(torch.sum(exp_a, dim=-1))
            values.append(exp_v)

        chunk_max = torch.stack(maxes, dim=-3)
        chunk_weights = torch.stack(weights, dim=-3)
        chunk_values = torch.stack(values, dim=-4)

        global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0]
        max_diffs = torch.exp(chunk_max - global_max)
        chunk_values *= max_diffs.unsqueeze(-1)
        chunk_weights *= max_diffs

        all_values = torch.sum(chunk_values, dim=-4)
        all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4)

        q_chunk_out = all_values / all_weights

        o[..., q_s:q_s + q_chunk_size, :, :] = q_chunk_out

    return o
class LowMemoryAttention(nn.Module):
    """
    Standard multi-head attention using AlphaFold's default layer
    initialization. Allows multiple bias vectors. Implements Rabe and Staats'
    low-memory self-attention algorithm.
    """
    def __init__(
        self,
        c_q: int,
        c_k: int,
        c_v: int,
        c_hidden: int,
        no_heads: int,
        gating: bool = True,
    ):
        """
        Args:
            c_q:
                Input dimension of query data
            c_k:
                Input dimension of key data
            c_v:
                Input dimension of value data
            c_hidden:
                Per-head hidden dimension
            no_heads:
                Number of attention heads
            gating:
                Whether the output should be gated using query data
            chunk_size:
                Trades memory for better parallelization. A low value
                corresponds to lower memory usage.
        """
        super().__init__()

        self.c_q = c_q
        self.c_k = c_k
        self.c_v = c_v
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.gating = gating

        self.linear_q = Linear(
            self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot"
        )
        self.linear_k = Linear(
            self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot"
        )
        self.linear_v = Linear(
            self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot"
        )
        self.linear_o = Linear(
            self.c_hidden * self.no_heads, self.c_q, init="final"
        )

        if self.gating:
            self.linear_g = Linear(
                self.c_q, self.c_hidden * self.no_heads, init="gating"
            )

        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=-1)

    def forward(
        self,
        q_x: torch.Tensor,
        k_x: torch.Tensor,
        v_x: torch.Tensor,
        q_chunk_size: int,
        kv_chunk_size: int,
        biases: Optional[List[torch.Tensor]] = None,
    ):
        if biases is None:
            biases = []
        else:
            biases = [
                b.expand(b.shape[:-2] + (q_x.shape[-2],) + (k_x.shape[-2],))
                for b in biases
            ]

        # [*, Q/K/V, H * C_hidden]
        q = self.linear_q(q_x)
        k = self.linear_k(k_x)
        v = self.linear_v(v_x)

        # [*, Q/K, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))
        k = k.view(k.shape[:-1] + (self.no_heads, -1))
        v = v.view(v.shape[:-1] + (self.no_heads, -1))

        q = q / math.sqrt(q.shape[-1])

        o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)

        if self.gating:
            g = self.sigmoid(self.linear_g(q_x))

            # [*, Q, H, C_hidden]
            g = g.view(g.shape[:-1] + (self.no_heads, -1))
            o = o * g

        # [*, Q, H * C_hidden]
        o = flatten_final_dims(o, 2)

        # [*, Q, C_q]
        o = self.linear_o(o)

        return o
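For reference, the chunked accumulation in _lma follows Rabe and Staats' streaming-softmax scheme: each key/value chunk contributes a per-chunk maximum, a sum of shifted exponentials, and an exponential-weighted value sum, and these are rescaled by the global maximum and combined into the exact softmax-weighted average. A minimal single-head sketch of that identity (toy shapes only, not part of the commit):

import torch

# Toy single-head case: 4 queries, 8 keys split into two chunks of 4,
# value dimension 3.
torch.manual_seed(0)
a = torch.randn(4, 8)   # attention logits [Q, K]
v = torch.randn(8, 3)   # values           [K, C]

# Reference: ordinary softmax attention over the full key axis.
ref = torch.softmax(a, dim=-1) @ v

# Chunked accumulation, mirroring the bookkeeping in _lma.
maxes, weights, values = [], [], []
for ks in range(0, 8, 4):
    a_chunk = a[:, ks:ks + 4]                  # [Q, K_chunk]
    m = a_chunk.max(dim=-1, keepdim=True)[0]   # per-chunk max, [Q, 1]
    e = torch.exp(a_chunk - m)                 # shifted exponentials
    maxes.append(m.squeeze(-1))                # [Q]
    weights.append(e.sum(dim=-1))              # [Q]
    values.append(e @ v[ks:ks + 4])            # [Q, C]

chunk_max = torch.stack(maxes, dim=0)          # [n_chunks, Q]
chunk_weights = torch.stack(weights, dim=0)    # [n_chunks, Q]
chunk_values = torch.stack(values, dim=0)      # [n_chunks, Q, C]

# Rescale each chunk by the global maximum, then combine.
global_max = chunk_max.max(dim=0, keepdim=True)[0]      # [1, Q]
scale = torch.exp(chunk_max - global_max)               # [n_chunks, Q]
num = (chunk_values * scale.unsqueeze(-1)).sum(dim=0)   # [Q, C]
den = (chunk_weights * scale).sum(dim=0).unsqueeze(-1)  # [Q, 1]

assert torch.allclose(num / den, ref, atol=1e-6)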
openfold/utils/tensor_utils.py

@@ -124,6 +124,7 @@ def _fetch_dims(tree):
    return shapes


+@torch.jit.ignore
def _flat_idx_to_idx(
    flat_idx: int,
    dims: Tuple[int],

@@ -135,6 +136,8 @@ def _flat_idx_to_idx(
    return tuple(reversed(idx))


+@torch.jit.ignore
def _get_minimal_slice_set(
    start: Sequence[int],
    end: Sequence[int],

@@ -252,18 +255,19 @@ def _get_minimal_slice_set(
    return [tuple(s) for s in slices]


+@torch.jit.ignore
def _chunk_slice(
    t: torch.Tensor,
    flat_start: int,
    flat_end: int,
    no_batch_dims: int,
-):
+) -> torch.Tensor:
    """
        Equivalent to

            t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]

-        but without the need for the reshape call, which can be
+        but without the need for the initial reshape call, which can be
        memory-intensive in certain situations. The only reshape operations
        in this function are performed on sub-tensors that scale with
        (flat_end - flat_start), the chunk size.

@@ -281,7 +285,6 @@ def _chunk_slice(
        batch_dims,
    )

-    #
    sliced_tensors = [t[s] for s in slices]

    return torch.cat(

@@ -352,7 +355,6 @@ def chunk_layer(
    i = 0
    out = None
    for _ in range(no_chunks):
        # Chunk the input
        if not low_mem:

@@ -382,7 +384,6 @@
        # Put the chunk in its pre-allocated space
        out_type = type(output_chunk)
        if out_type is dict:
            def assign(d1, d2):
                for k, v in d1.items():
                    if type(v) is dict:
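The _chunk_slice docstring above claims exact equivalence with a flatten-then-slice expression; a small sanity check of that equivalence (the tensor shape and slice bounds here are illustrative only):

import torch
from openfold.utils.tensor_utils import _chunk_slice

# A small tensor with three batch dimensions and one trailing feature dim.
t = torch.arange(2 * 3 * 4 * 5, dtype=torch.float).view(2, 3, 4, 5)
no_batch_dims = 3
flat_start, flat_end = 5, 17

# Reference from the docstring: flatten the batch dims, then slice.
ref = t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]

# _chunk_slice should produce the same chunk without the up-front reshape.
chunk = _chunk_slice(t, flat_start, flat_end, no_batch_dims)
assert torch.equal(chunk, ref)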
scripts/run_unit_tests.sh
#!/bin/bash
#
-CUDA_VISIBLE_DEVICES="5" python3 -m unittest "$@" || \
+CUDA_VISIBLE_DEVICES="0" python3 -m unittest "$@" || \
    echo -e "\nTest(s) failed. Make sure you've installed all Python dependencies."
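Since the script passes "$@" straight through to python3 -m unittest, individual test modules can be selected on the command line, for example: bash scripts/run_unit_tests.sh tests.test_primitives. Note that the new LMA test moves its tensors to the GPU with .cuda(), so it assumes at least one visible CUDA device.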
tests/compare_utils.py
-import os
-os.environ["CUDA_VISIBLE_DEVICES"] = "4,"
import importlib
import pkgutil
import sys
...
tests/test_primitives.py (new file, 0 → 100644)
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import numpy as np
import unittest

from openfold.model.primitives import (
    Attention,
    LowMemoryAttention,
)
from tests.config import consts


class TestLMA(unittest.TestCase):
    def test_lma_vs_attention(self):
        batch_size = consts.batch_size
        c_hidden = 32
        n = 2**12
        no_heads = 4

        q = torch.rand(batch_size, n, c_hidden).cuda()
        k = torch.rand(batch_size, n, c_hidden).cuda()
        v = torch.rand(batch_size, n, c_hidden).cuda()

        bias = [torch.rand(no_heads, 1, n)]
        bias = [b.cuda() for b in bias]

        gating_fill = torch.rand(c_hidden * no_heads, c_hidden)
        o_fill = torch.rand(c_hidden, c_hidden * no_heads)

        lma = LowMemoryAttention(
            c_hidden, c_hidden, c_hidden, c_hidden, no_heads
        ).cuda()
        a = Attention(
            c_hidden, c_hidden, c_hidden, c_hidden, no_heads
        ).cuda()

        with torch.no_grad():
            for n, p in lma.named_parameters():
                attrs = n.split('.')
                param = a
                for attr in attrs:
                    param = getattr(param, attr)
                param.copy_(p)

            for m in [lma, a]:
                m.linear_g.weight.copy_(gating_fill)
                m.linear_o.weight.copy_(o_fill)

        with torch.no_grad():
            l = lma(q, k, v, 1024, 4096, biases=bias)
            real = a(q, k, v, biases=bias)

        self.assertTrue(torch.max(torch.abs(l - real)) < consts.eps)


if __name__ == "__main__":
    unittest.main()
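For comparison with the existing Attention module, note the only difference in the calling convention exercised by this test: LowMemoryAttention takes q_chunk_size and kv_chunk_size as positional arguments ahead of the optional biases. A minimal CPU sketch of the two calls (shapes, chunk sizes, and the bias are illustrative; the outputs only agree when the weights are copied across, as the test above does):

import torch
from openfold.model.primitives import Attention, LowMemoryAttention

c_hidden, no_heads = 32, 4
n = 256
x = torch.rand(1, n, c_hidden)
bias = [torch.rand(no_heads, 1, n)]

attn = Attention(c_hidden, c_hidden, c_hidden, c_hidden, no_heads)
lma = LowMemoryAttention(c_hidden, c_hidden, c_hidden, c_hidden, no_heads)

with torch.no_grad():
    # Standard attention: inputs plus biases.
    out_full = attn(x, x, x, biases=bias)
    # Low-memory attention: query/key chunk sizes precede the biases.
    out_chunked = lma(x, x, x, 64, 128, biases=bias)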