OpenDAS / OpenFold

Commit 4bd4ad93
Add first attempt at scripting attention

Authored Sep 26, 2021 by Gustaf Ahdritz
Parent: dd8e44b3

Showing 5 changed files with 30 additions and 18 deletions:
openfold/model/msa.py                    (+2, -2)
openfold/model/primitives.py             (+22, -10)
openfold/model/template.py               (+2, -2)
openfold/model/triangular_attention.py   (+2, -2)
openfold/utils/tensor_utils.py           (+2, -2)
openfold/model/msa.py

@@ -17,7 +17,7 @@ import math
 import torch
 import torch.nn as nn
 
-from openfold.model.primitives import Linear, Attention
+from openfold.model.primitives import Linear, scripted_attention
 from openfold.utils.tensor_utils import (
     chunk_layer,
     permute_final_dims,
@@ -69,7 +69,7 @@ class MSAAttention(nn.Module):
             self.c_z, self.no_heads, bias=False, init="normal"
         )
 
-        self.mha = Attention(
+        self.mha = scripted_attention(
             self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
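The same two-line swap recurs in template.py and triangular_attention.py below: each call site replaces direct construction of Attention with the new scripted_attention factory added to primitives.py. A minimal sketch of the pattern follows; ToyAttention and its dimensions are invented for illustration and are not OpenFold's actual code:

import torch
import torch.nn as nn


class ToyAttention(nn.Module):
    """Stand-in for OpenFold's Attention; illustrative only."""

    def __init__(self, c_in: int, c_hidden: int):
        super().__init__()
        self.linear_q = nn.Linear(c_in, c_hidden)
        self.linear_k = nn.Linear(c_in, c_hidden)
        self.linear_v = nn.Linear(c_in, c_hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q = self.linear_q(x)
        k = self.linear_k(x)
        v = self.linear_v(x)
        a = torch.softmax(torch.matmul(q, k.transpose(-1, -2)), dim=-1)
        return torch.matmul(a, v)


def scripted_toy_attention(*args, **kwargs):
    # Same shape as the commit's scripted_attention: construct the module
    # eagerly, then compile it with TorchScript. The resulting ScriptModule
    # is called exactly like the eager module at every call site.
    return torch.jit.script(ToyAttention(*args, **kwargs))


mha = scripted_toy_attention(32, 64)
out = mha(torch.randn(4, 16, 32))  # [4, 16, 64]

Because torch.jit.script compiles the module's forward, everything that forward does must stay within the TorchScript subset of Python, which is what drives the body rewrites in primitives.py below.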
openfold/model/primitives.py

@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import math
-from typing import Optional, Callable
+from typing import Optional, Callable, List
 
 import numpy as np
 import torch
@@ -212,7 +212,7 @@ class Attention(nn.Module):
             self.c_hidden * self.no_heads, self.c_q, init="final"
         )
 
-        if(self.gating):
+        if(self.gating is not None):
            self.linear_g = Linear(
                self.c_q, self.c_hidden * self.no_heads, init="gating"
            )
 
        self.sigmoid = nn.Sigmoid()
@@ -222,7 +222,7 @@ class Attention(nn.Module):
         q_x: torch.Tensor,
         k_x: torch.Tensor,
         v_x: torch.Tensor,
-        biases: bool = None,
+        biases: Optional[List[torch.Tensor]] = None,
     ) -> torch.Tensor:
         """
         Args:
@@ -235,20 +235,26 @@ class Attention(nn.Module):
         Returns
             [*, Q, C_q] attention update
         """
+        # Flatten batch dims
+        batch_dims = q_x.shape[:-2]
+        q_x = q_x.view((-1,) + q_x.shape[-2:])
+        k_x = k_x.view((-1,) + k_x.shape[-2:])
+        v_x = v_x.view((-1,) + v_x.shape[-2:])
+
         # [*, Q/K/V, H * C_hidden]
         q = self.linear_q(q_x)
         k = self.linear_k(k_x)
         v = self.linear_v(v_x)
 
         # [*, Q/K, H, C_hidden]
-        q = q.view(*q.shape[:-1], self.no_heads, -1)
-        k = k.view(*k.shape[:-1], self.no_heads, -1)
-        v = v.view(*v.shape[:-1], self.no_heads, -1)
+        q = q.view(q.shape[:-1] + (self.no_heads, -1))
+        k = k.view(k.shape[:-1] + (self.no_heads, -1))
+        v = v.view(v.shape[:-1] + (self.no_heads, -1))
 
         # [*, H, Q, K]
         a = torch.matmul(
-            permute_final_dims(q, 1, 0, 2),  # [*, H, Q, C_hidden]
-            permute_final_dims(k, 1, 2, 0),  # [*, H, C_hidden, K]
+            q.permute(0, 2, 1, 3),  # [*, H, Q, C_hidden]
+            k.permute(0, 2, 3, 1),  # [*, H, C_hidden, K]
         )
 
         norm = 1 / math.sqrt(self.c_hidden)  # [1]
         a = a * norm
@@ -260,7 +266,7 @@ class Attention(nn.Module):
         # [*, H, Q, C_hidden]
         o = torch.matmul(
             a,
-            permute_final_dims(v, 1, 0, 2),  # [*, H, V, C_hidden]
+            v.permute(0, 2, 1, 3),  # [*, H, V, C_hidden]
         )
 
         # [*, Q, H, C_hidden]
@@ -268,7 +274,7 @@ class Attention(nn.Module):
         if(self.gating):
             g = self.sigmoid(self.linear_g(q_x))
             # [*, Q, H, C_hidden]
-            g = g.view(*g.shape[:-1], self.no_heads, -1)
+            g = g.view(g.shape[:-1] + (self.no_heads, -1))
             o = o * g
 
         # [*, Q, H * C_hidden]
@@ -276,5 +282,11 @@ class Attention(nn.Module):
         # [*, Q, C_q]
         o = self.linear_o(o)
 
+        # Restore the batch dims
+        o = o.reshape(batch_dims + o.shape[1:])
+
         return o
 
+
+def scripted_attention(*args, **kwargs):
+    return torch.jit.script(Attention(*args, **kwargs))
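The body changes above read as TorchScript-compatibility rewrites rather than behavioral ones. TorchScript does not compile star-unpacking of a shape inside a call, hence view(q.shape[:-1] + (self.no_heads, -1)) in place of view(*q.shape[:-1], self.no_heads, -1); it does not compile varargs helpers like permute_final_dims(tensor, *inds), hence the fixed-index Tensor.permute calls, which in turn need tensors of known rank, presumably why the leading batch dims are flattened on entry and restored before returning; and it type-checks signatures, so biases gets a correct Optional[List[torch.Tensor]] annotation in place of the wrong bool. A self-contained sketch of the same idioms follows; toy_scripted_attention is hypothetical, omits the real module's Linear projections and gating, and only demonstrates the scriptable shape handling:

import math
from typing import List, Optional

import torch


@torch.jit.script
def toy_scripted_attention(
    q: torch.Tensor,  # [*, Q, H * C_hidden]
    k: torch.Tensor,  # [*, K, H * C_hidden]
    v: torch.Tensor,  # [*, K, H * C_hidden]
    no_heads: int,
    biases: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor:
    c_hidden = q.shape[-1] // no_heads

    # Flatten the leading batch dims so the fixed-index permutes below
    # always see rank-4 tensors (the commit does the same via batch_dims).
    batch_dims = q.shape[:-2]
    q = q.reshape([-1] + q.shape[-2:])
    k = k.reshape([-1] + k.shape[-2:])
    v = v.reshape([-1] + v.shape[-2:])

    # Split heads by concatenating shape lists; *-unpacking a shape into
    # view() is not scriptable.
    q = q.view(q.shape[:-1] + [no_heads, c_hidden])  # [B, Q, H, C]
    k = k.view(k.shape[:-1] + [no_heads, c_hidden])  # [B, K, H, C]
    v = v.view(v.shape[:-1] + [no_heads, c_hidden])  # [B, K, H, C]

    # Fixed-index permutes stand in for the varargs permute_final_dims.
    a = torch.matmul(
        q.permute(0, 2, 1, 3),  # [B, H, Q, C]
        k.permute(0, 2, 3, 1),  # [B, H, C, K]
    ) * (1 / math.sqrt(c_hidden))

    if biases is not None:
        for b in biases:
            a = a + b
    a = torch.softmax(a, dim=-1)

    o = torch.matmul(a, v.permute(0, 2, 1, 3))  # [B, H, Q, C]
    o = o.permute(0, 2, 1, 3)                   # [B, Q, H, C]
    o = o.reshape(o.shape[:-2] + [no_heads * c_hidden])

    # Restore the batch dims, as the commit does before returning.
    return o.reshape(batch_dims + o.shape[1:])


q = torch.randn(2, 5, 10, 64)  # batch dims [2, 5], H=4, C_hidden=16
out = toy_scripted_attention(q, q, q, 4)
print(out.shape)  # torch.Size([2, 5, 10, 64])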
openfold/model/template.py

@@ -18,7 +18,7 @@ import math
 import torch
 import torch.nn as nn
 
-from openfold.model.primitives import Linear, Attention
+from openfold.model.primitives import Linear, scripted_attention
 from openfold.utils.deepspeed import checkpoint_blocks
 from openfold.model.dropout import (
     DropoutRowwise,
@@ -69,7 +69,7 @@ class TemplatePointwiseAttention(nn.Module):
         self.no_heads = no_heads
         self.chunk_size = chunk_size
 
-        self.mha = Attention(
+        self.mha = scripted_attention(
             self.c_z, self.c_t, self.c_t, self.c_hidden, self.no_heads,
             gating=False,
openfold/model/triangular_attention.py

@@ -18,7 +18,7 @@ import math
 import torch
 import torch.nn as nn
 
-from openfold.model.primitives import Linear, Attention
+from openfold.model.primitives import Linear, scripted_attention
 from openfold.utils.tensor_utils import (
     chunk_layer,
     permute_final_dims,
@@ -57,7 +57,7 @@ class TriangleAttention(nn.Module):
         self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")
 
-        self.mha = Attention(
+        self.mha = scripted_attention(
             self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
openfold/utils/tensor_utils.py

@@ -24,8 +24,8 @@ def permute_final_dims(tensor, *inds):
     return tensor.permute(*first_inds, *[zero_index + i for i in inds])
 
 
-def flatten_final_dims(tensor, no_dims):
-    return tensor.reshape(*tensor.shape[:-no_dims], -1)
+def flatten_final_dims(tensor: torch.Tensor, no_dims: int):
+    return tensor.reshape(tensor.shape[:-no_dims] + (-1,))
 
 
 def masked_mean(mask, value, dim, eps=1e-10):
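The flatten_final_dims rewrite follows the same pattern as the Attention changes: add type annotations and replace *-unpacking with shape concatenation, presumably to keep the helper inside the TorchScript subset. A quick eager-mode check (not from the commit) that the two spellings agree:

import torch

t = torch.randn(2, 3, 4, 5)

old = t.reshape(*t.shape[:-2], -1)     # original spelling, not scriptable
new = t.reshape(t.shape[:-2] + (-1,))  # this commit's spelling

assert old.shape == new.shape == (2, 3, 20)

Note that permute_final_dims itself keeps its *inds varargs signature here, which is consistent with the scripted Attention.forward no longer calling it.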