Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastFold
Commits
6fbc402e
Unverified
Commit
6fbc402e
authored
Dec 02, 2022
by
LuGY
Committed by
GitHub
Dec 02, 2022
Browse files
refactor chunk (#117)
parent
3b096d67
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
131 additions
and
85 deletions
+131
-85
fastfold/model/fastnn/ops.py
fastfold/model/fastnn/ops.py
+126
-82
fastfold/model/fastnn/template.py
fastfold/model/fastnn/template.py
+5
-3
No files found.
fastfold/model/fastnn/ops.py
View file @
6fbc402e
...
...
@@ -91,13 +91,12 @@ class ChunkTransition(nn.Module):
self
.
linear2
=
Linear
(
n
*
d
,
d
,
initializer
=
'zeros'
)
def
forward
(
self
,
src
):
para_dim
=
src
.
shape
[
1
]
chunk_size
=
48
if
CHUNK_SIZE
==
None
:
chunk_size
=
para_dim
out
=
self
.
norm
(
src
)
out
=
self
.
linear2
(
F
.
relu
(
self
.
linear1
(
out
)))
else
:
chunk_size
=
CHUNK_SIZE
*
48
para_dim
=
src
.
shape
[
1
]
out
=
torch
.
empty_like
(
src
)
for
ax
in
range
(
0
,
para_dim
,
chunk_size
):
if
DEBUG
and
ax
>
10
:
...
...
@@ -155,11 +154,14 @@ class OutProductMean(nn.Module):
right_act_all
=
gather_async_opp
(
right_act_all
,
work
,
dim
=
2
)
right_act_all
=
M_mask
*
right_act_all
if
CHUNK_SIZE
==
None
:
out
=
torch
.
einsum
(
'bsid, bsje->bijde'
,
left_act
,
right_act_all
)
out
=
rearrange
(
out
,
'b i j d e -> b i j (d e)'
)
out
=
self
.
o_linear
(
out
)
Z
=
out
/
norm
else
:
para_dim
=
left_act
.
shape
[
2
]
chunk_size
=
CHUNK_SIZE
if
CHUNK_SIZE
==
None
:
chunk_size
=
para_dim
for
ax
in
range
(
0
,
para_dim
,
chunk_size
):
left_act_part
=
left_act
[:,
:,
ax
:
ax
+
chunk_size
,
:]
O
=
torch
.
einsum
(
'bsid,bsje->bijde'
,
left_act_part
,
right_act_all
)
...
...
@@ -291,11 +293,6 @@ class SelfAttention(nn.Module):
:param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv]
"""
para_dim
=
in_data
.
shape
[
1
]
chunk_size
=
CHUNK_SIZE
if
CHUNK_SIZE
==
None
:
chunk_size
=
para_dim
if
nonbatched_bias
is
not
None
:
if
nonbatched_bias
[
-
1
]
==
-
1
:
bias
=
nonbatched_bias
[
0
]
...
...
@@ -304,6 +301,31 @@ class SelfAttention(nn.Module):
bias
=
gather_async_opp
(
*
nonbatched_bias
,
dim
=
1
)
bias
=
rearrange
(
bias
,
'b q k h -> b h q k'
)
if
CHUNK_SIZE
==
None
:
qkv
=
self
.
to_qkv
(
in_data
).
chunk
(
3
,
dim
=-
1
)
q
,
k
,
v
=
map
(
lambda
t
:
rearrange
(
t
,
'b1 b2 n (h d) -> b1 b2 h n d'
,
h
=
self
.
n_head
),
qkv
)
q
=
q
*
self
.
scaling
logits
=
torch
.
matmul
(
q
,
k
.
transpose
(
-
1
,
-
2
))
if
nonbatched_bias
is
not
None
:
weights
=
fused_softmax
(
logits
,
mask
,
bias
.
unsqueeze
(
1
))
else
:
weights
=
fused_softmax
(
logits
,
mask
)
weighted_avg
=
torch
.
matmul
(
weights
,
v
)
weighted_avg
=
rearrange
(
weighted_avg
,
'b1 b2 h n d -> b1 b2 n (h d)'
)
if
self
.
gating
:
gate_values
=
self
.
gating_linear
(
in_data
)
weighted_avg
=
bias_sigmod_ele
(
gate_values
,
self
.
gating_bias
,
weighted_avg
)
output
=
self
.
o_linear
(
weighted_avg
)
else
:
para_dim
=
in_data
.
shape
[
1
]
chunk_size
=
CHUNK_SIZE
output
=
[]
for
ax
in
range
(
0
,
para_dim
,
chunk_size
):
...
...
@@ -981,16 +1003,16 @@ class ChunkMSAColumnGlobalAttention(nn.Module):
)
def
forward
(
self
,
M_raw
,
M_mask
):
para_dim
=
M_raw
.
shape
[
2
]
if
CHUNK_SIZE
is
None
:
chunk_size
=
para_dim
m
=
self
.
layernormM
(
M_raw
.
transpose
(
-
2
,
-
3
))
m
=
self
.
global_attention
(
m
,
M_mask
.
transpose
(
-
1
,
-
2
))
m
=
m
.
transpose
(
-
2
,
-
3
)
M_raw
=
M_raw
+
m
else
:
chunk_size
=
CHUNK_SIZE
para_dim
=
M_raw
.
shape
[
2
]
for
i
in
range
(
0
,
para_dim
,
chunk_size
):
if
DEBUG
and
i
>
10
:
break
m
=
M_raw
[:,
:,
i
:
i
+
chunk_size
,
:].
transpose
(
-
2
,
-
3
)
m
=
self
.
layernormM
(
m
)
m_mask
=
M_mask
[:,
:,
i
:
i
+
chunk_size
].
transpose
(
-
1
,
-
2
)
...
...
@@ -1109,12 +1131,12 @@ class RecyclingEmbedder(nn.Module):
# [*, N, N, no_bins]
d
=
((
d
>
squared_bins
)
*
(
d
<
upper
)).
type
(
x
.
dtype
)
# [*, N, N, C_z]
para_dim
=
d
.
shape
[
1
]
if
CHUNK_SIZE
==
None
:
chunk_size
=
para_dim
d
=
self
.
linear
(
d
)
z
=
d
+
self
.
layer_norm_z
(
z
)
else
:
chunk_size
=
CHUNK_SIZE
*
48
para_dim
=
d
.
shape
[
1
]
for
i
in
range
(
0
,
para_dim
,
chunk_size
):
di
=
self
.
linear
(
d
[
i
:
i
+
chunk_size
,
:,
:])
...
...
@@ -1152,10 +1174,33 @@ class GlobalAttention(nn.Module):
def
forward
(
self
,
m
,
mask
):
if
CHUNK_SIZE
==
None
:
q
=
torch
.
sum
(
m
*
mask
.
unsqueeze
(
-
1
),
dim
=-
2
)
/
(
torch
.
sum
(
mask
,
dim
=-
1
)[...,
None
]
+
self
.
eps
)
q
=
q
*
self
.
scaling
q
=
self
.
to_q
(
q
)
q
=
q
.
view
(
q
.
shape
[:
-
1
]
+
(
self
.
n_head
,
-
1
))
k
,
v
=
self
.
to_kv
(
m
).
chunk
(
2
,
dim
=-
1
)
logits
=
torch
.
matmul
(
q
,
k
.
transpose
(
-
1
,
-
2
))
weights
=
fused_softmax
(
logits
,
mask
)
weighted_avg
=
torch
.
matmul
(
weights
,
v
)
weighted_avg
=
rearrange
(
weighted_avg
,
"b1 b2 h d -> b1 b2 (h d)"
)
gate_values
=
self
.
gating_linear
(
m
)
weighted_avg
=
bias_sigmod_ele
(
gate_values
,
self
.
gating_bias
,
weighted_avg
.
unsqueeze
(
-
2
)
)
m
=
self
.
o_linear
(
weighted_avg
)
else
:
para_dim
=
m
.
shape
[
1
]
chunk_size
=
CHUNK_SIZE
if
CHUNK_SIZE
==
None
:
chunk_size
=
para_dim
output
=
[]
for
ax
in
range
(
0
,
para_dim
,
chunk_size
):
...
...
@@ -1190,7 +1235,6 @@ class GlobalAttention(nn.Module):
return
m
class
InputEmbedder
(
nn
.
Module
):
"""
Embeds a subset of the input features.
...
...
fastfold/model/fastnn/template.py
View file @
6fbc402e
...
...
@@ -387,9 +387,11 @@ class TemplatePairStack(nn.Module):
args
=
(
t
,),
blocks_per_ckpt
=
self
.
blocks_per_ckpt
if
self
.
training
else
None
,
)
if
not
self
.
training
:
for
i
in
range
(
0
,
t
.
shape
[
0
]):
t
[
i
]
=
self
.
layer_norm
(
t
[
i
])
else
:
t
=
self
.
layer_norm
(
t
[
i
])
return
t
def
inplace
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment