OpenDAS / FastFold · Commit 6fbc402e (unverified)

refactor chunk (#117)

Authored Dec 02, 2022 by LuGY, committed via GitHub on Dec 02, 2022.
Parent: 3b096d67

Showing 2 changed files with 131 additions and 85 deletions:

    fastfold/model/fastnn/ops.py        +126  -82
    fastfold/model/fastnn/template.py     +5   -3
fastfold/model/fastnn/ops.py
@@ -91,20 +91,19 @@ class ChunkTransition(nn.Module):
         self.linear2 = Linear(n * d, d, initializer='zeros')
 
     def forward(self, src):
-        para_dim = src.shape[1]
-        chunk_size = 48
         if CHUNK_SIZE == None:
-            chunk_size = para_dim
+            out = self.norm(src)
+            out = self.linear2(F.relu(self.linear1(out)))
         else:
             chunk_size = CHUNK_SIZE * 48
-
-        out = torch.empty_like(src)
-        for ax in range(0, para_dim, chunk_size):
-            if DEBUG and ax > 10:
-                break
-            x = self.norm(src[:, ax:ax + chunk_size, :, :])
-            x = self.linear2(F.relu(self.linear1(x)))
-            out[:, ax:ax + chunk_size, :, :] = x
+            para_dim = src.shape[1]
+            out = torch.empty_like(src)
+            for ax in range(0, para_dim, chunk_size):
+                if DEBUG and ax > 10:
+                    break
+                x = self.norm(src[:, ax:ax + chunk_size, :, :])
+                x = self.linear2(F.relu(self.linear1(x)))
+                out[:, ax:ax + chunk_size, :, :] = x
         out.add_(src)
         return out
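The two-path shape above repeats throughout this commit: with CHUNK_SIZE unset the whole tensor goes through one pass, otherwise the work is looped over slices of the residue axis so only one chunk's intermediates are alive at a time. A minimal, self-contained sketch of that idea (hypothetical TinyTransition, not FastFold's ChunkTransition, reading a module-level CHUNK_SIZE the way ops.py does):

import torch
import torch.nn as nn
import torch.nn.functional as F

CHUNK_SIZE = None  # None = full tensor; an int trades speed for lower peak memory


class TinyTransition(nn.Module):
    """Hypothetical stand-in for ChunkTransition: LayerNorm -> MLP -> residual."""

    def __init__(self, d, n=4):
        super().__init__()
        self.norm = nn.LayerNorm(d)
        self.linear1 = nn.Linear(d, n * d)
        self.linear2 = nn.Linear(n * d, d)

    def forward(self, src):
        if CHUNK_SIZE is None:
            # full path: one pass, but the n*d hidden tensor covers every row
            out = self.linear2(F.relu(self.linear1(self.norm(src))))
        else:
            # chunked path: the hidden tensor only ever covers one chunk of rows
            chunk = CHUNK_SIZE * 48
            out = torch.empty_like(src)
            for ax in range(0, src.shape[1], chunk):
                x = self.norm(src[:, ax:ax + chunk])
                out[:, ax:ax + chunk] = self.linear2(F.relu(self.linear1(x)))
        return out + src  # same effect as the in-place out.add_(src)


x = torch.randn(1, 100, 8, 16)
block = TinyTransition(16)
full = block(x)
CHUNK_SIZE = 1          # now forward() takes the chunked branch
assert torch.allclose(full, block(x), atol=1e-5)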
@@ -155,18 +154,21 @@ class OutProductMean(nn.Module):
         right_act_all = gather_async_opp(right_act_all, work, dim=2)
         right_act_all = M_mask * right_act_all
 
-        para_dim = left_act.shape[2]
-        chunk_size = CHUNK_SIZE
         if CHUNK_SIZE == None:
-            chunk_size = para_dim
-
-        for ax in range(0, para_dim, chunk_size):
-            left_act_part = left_act[:, :, ax:ax + chunk_size, :]
-            O = torch.einsum('bsid,bsje->bijde', left_act_part, right_act_all)
-            O = rearrange(O, 'b i j d e -> b i j (d e)')
-            O = self.o_linear(O)
-            norm0 = norm[:, ax:ax + chunk_size, :, :]
-            Z[:, ax:ax + chunk_size, :, :] = O / norm0
+            out = torch.einsum('bsid, bsje->bijde', left_act, right_act_all)
+            out = rearrange(out, 'b i j d e -> b i j (d e)')
+            out = self.o_linear(out)
+            Z = out / norm
+        else:
+            para_dim = left_act.shape[2]
+            chunk_size = CHUNK_SIZE
+            for ax in range(0, para_dim, chunk_size):
+                left_act_part = left_act[:, :, ax:ax + chunk_size, :]
+                O = torch.einsum('bsid,bsje->bijde', left_act_part, right_act_all)
+                O = rearrange(O, 'b i j d e -> b i j (d e)')
+                O = self.o_linear(O)
+                norm0 = norm[:, ax:ax + chunk_size, :, :]
+                Z[:, ax:ax + chunk_size, :, :] = O / norm0
 
         return Z + Z_raw
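The payoff is clearest here: the einsum builds a [b, i, j, d, e] outer-product tensor, which dominates peak memory, so the chunked branch slices the i axis and never materializes more than one chunk of it. A hedged sketch with hypothetical names (outer_product_mean, o_linear), using plain PyTorch rather than the FastFold kernels:

import torch
from einops import rearrange


def outer_product_mean(left_act, right_act, norm, o_linear, chunk_size=None):
    """Sketch: left_act [b, s, i, d], right_act [b, s, j, e], norm [b, i, j, 1]."""
    if chunk_size is None:
        out = torch.einsum('bsid,bsje->bijde', left_act, right_act)
        out = rearrange(out, 'b i j d e -> b i j (d e)')
        return o_linear(out) / norm

    b, _, i_dim, _ = left_act.shape
    j_dim = right_act.shape[2]
    Z = left_act.new_empty(b, i_dim, j_dim, o_linear.out_features)
    for ax in range(0, i_dim, chunk_size):
        # only a [b, chunk, j, d, e] slice of the outer product exists at a time
        O = torch.einsum('bsid,bsje->bijde',
                         left_act[:, :, ax:ax + chunk_size, :], right_act)
        O = rearrange(O, 'b i j d e -> b i j (d e)')
        Z[:, ax:ax + chunk_size] = o_linear(O) / norm[:, ax:ax + chunk_size]
    return Z


left = torch.randn(1, 5, 12, 4)
right = torch.randn(1, 5, 12, 4)
norm = torch.full((1, 12, 12, 1), 5.0)
proj = torch.nn.Linear(16, 8)
assert torch.allclose(outer_product_mean(left, right, norm, proj),
                      outer_product_mean(left, right, norm, proj, chunk_size=5),
                      atol=1e-5)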
@@ -291,11 +293,6 @@ class SelfAttention(nn.Module):
         :param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv]
         """
-        para_dim = in_data.shape[1]
-        chunk_size = CHUNK_SIZE
-        if CHUNK_SIZE == None:
-            chunk_size = para_dim
-
         if nonbatched_bias is not None:
             if nonbatched_bias[-1] == -1:
                 bias = nonbatched_bias[0]
@@ -303,14 +300,9 @@ class SelfAttention(nn.Module):
                 # logits += nonbatched_bias.unsqueeze(1)
                 bias = gather_async_opp(*nonbatched_bias, dim=1)
             bias = rearrange(bias, 'b q k h -> b h q k')
 
-        output = []
-        for ax in range(0, para_dim, chunk_size):
-            in_data_part = in_data[:, ax:ax + chunk_size, :, :]
-            mask_part = mask[:, ax:ax + chunk_size, :]
-
-            qkv = self.to_qkv(in_data_part).chunk(3, dim=-1)
+        if CHUNK_SIZE == None:
+            qkv = self.to_qkv(in_data).chunk(3, dim=-1)
             q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)
 
             q = q * self.scaling
@@ -318,25 +310,55 @@ class SelfAttention(nn.Module):
             logits = torch.matmul(q, k.transpose(-1, -2))
 
             if nonbatched_bias is not None:
-                # logits += bias.unsqueeze(1)
-                # logits += (1e9 * (mask_part - 1))[..., :, None, None, :]
-                # weights = torch.nn.functional.softmax(logits, -1)
-                weights = fused_softmax(logits, mask_part, bias.unsqueeze(1))
+                weights = fused_softmax(logits, mask, bias.unsqueeze(1))
             else:
-                # logits += (1e9 * (mask_part - 1))[..., :, None, None, :]
-                # weights = torch.nn.functional.softmax(logits, -1)
-                weights = fused_softmax(logits, mask_part)
+                weights = fused_softmax(logits, mask)
 
             weighted_avg = torch.matmul(weights, v)
             weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
 
             if self.gating:
-                gate_values = self.gating_linear(in_data_part)
+                gate_values = self.gating_linear(in_data)
                 weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)
 
-            output.append(self.o_linear(weighted_avg))
-
-        output = torch.cat(output, dim=1)
+            output = self.o_linear(weighted_avg)
+        else:
+            para_dim = in_data.shape[1]
+            chunk_size = CHUNK_SIZE
+            output = []
+            for ax in range(0, para_dim, chunk_size):
+                in_data_part = in_data[:, ax:ax + chunk_size, :, :]
+                mask_part = mask[:, ax:ax + chunk_size, :]
+
+                qkv = self.to_qkv(in_data_part).chunk(3, dim=-1)
+                q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)
+
+                q = q * self.scaling
+
+                logits = torch.matmul(q, k.transpose(-1, -2))
+
+                if nonbatched_bias is not None:
+                    # logits += bias.unsqueeze(1)
+                    # logits += (1e9 * (mask_part - 1))[..., :, None, None, :]
+                    # weights = torch.nn.functional.softmax(logits, -1)
+                    weights = fused_softmax(logits, mask_part, bias.unsqueeze(1))
+                else:
+                    # logits += (1e9 * (mask_part - 1))[..., :, None, None, :]
+                    # weights = torch.nn.functional.softmax(logits, -1)
+                    weights = fused_softmax(logits, mask_part)
+
+                weighted_avg = torch.matmul(weights, v)
+                weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
+
+                if self.gating:
+                    gate_values = self.gating_linear(in_data_part)
+                    weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)
+
+                output.append(self.o_linear(weighted_avg))
+
+            output = torch.cat(output, dim=1)
 
         return output
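Note how the refactor keeps each path's mask consistent with its input: the full path hands fused_softmax the whole mask, while the chunked path slices in_data and mask along the same axis before forming q, k and v. A plain-softmax sketch of why the per-chunk outputs concatenate to the full result (hypothetical row_chunked_attention helper, no fused kernels):

import torch
from einops import rearrange


def row_chunked_attention(x, mask, w_qkv, n_head, chunk_size=None):
    """Sketch: x [b1, b2, n, d], mask [b1, b2, n] (1 = keep, 0 = pad).
    Rows along dim 1 attend independently, so slicing x and mask together
    and concatenating the per-chunk outputs reproduces the full result."""

    def attend(x_part, mask_part):
        q, k, v = map(
            lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=n_head),
            w_qkv(x_part).chunk(3, dim=-1))
        q = q * q.shape[-1] ** -0.5
        logits = torch.matmul(q, k.transpose(-1, -2))
        logits = logits + (1e9 * (mask_part - 1))[..., :, None, None, :]
        weights = torch.softmax(logits, dim=-1)
        return rearrange(torch.matmul(weights, v), 'b1 b2 h n d -> b1 b2 n (h d)')

    if chunk_size is None:
        return attend(x, mask)
    return torch.cat([attend(x[:, ax:ax + chunk_size], mask[:, ax:ax + chunk_size])
                      for ax in range(0, x.shape[1], chunk_size)], dim=1)


x = torch.randn(1, 10, 7, 32)
mask = torch.ones(1, 10, 7)
mask[..., -2:] = 0
w_qkv = torch.nn.Linear(32, 3 * 32, bias=False)
assert torch.allclose(row_chunked_attention(x, mask, w_qkv, n_head=4),
                      row_chunked_attention(x, mask, w_qkv, n_head=4, chunk_size=3),
                      atol=1e-5)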
@@ -981,22 +1003,22 @@ class ChunkMSAColumnGlobalAttention(nn.Module):
         )
 
     def forward(self, M_raw, M_mask):
-        para_dim = M_raw.shape[2]
         if CHUNK_SIZE is None:
-            chunk_size = para_dim
+            m = self.layernormM(M_raw.transpose(-2, -3))
+            m = self.global_attention(m, M_mask.transpose(-1, -2))
+            m = m.transpose(-2, -3)
+            M_raw = M_raw + m
         else:
             chunk_size = CHUNK_SIZE
-
-        for i in range(0, para_dim, chunk_size):
-            if DEBUG and i > 10:
-                break
-            m = M_raw[:, :, i:i + chunk_size, :].transpose(-2, -3)
-            m = self.layernormM(m)
-            m_mask = M_mask[:, :, i:i + chunk_size].transpose(-1, -2)
-            m = self.global_attention(m, m_mask)
-            m = m.transpose(-2, -3)
-            M_raw[:, :, i:i + chunk_size, :] += m
+            para_dim = M_raw.shape[2]
+            for i in range(0, para_dim, chunk_size):
+                m = M_raw[:, :, i:i + chunk_size, :].transpose(-2, -3)
+                m = self.layernormM(m)
+                m_mask = M_mask[:, :, i:i + chunk_size].transpose(-1, -2)
+                m = self.global_attention(m, m_mask)
+                m = m.transpose(-2, -3)
+                M_raw[:, :, i:i + chunk_size, :] += m
 
         return M_raw
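The column variant adds one twist: each chunk is taken along dim 2, transposed so the sliced columns become attention rows, processed, transposed back, and accumulated into M_raw in place, so no concatenation buffer is needed. A sketch with a stand-in for the attention call (hypothetical column_update, not the real global_attention):

import torch
import torch.nn as nn


def column_update(M, M_mask, layernorm, attn, chunk_size=None):
    """Sketch: M [b, s, r, d], M_mask [b, s, r]. Slice dim 2, transpose so the
    sliced columns become rows, run the sub-module, transpose back, and add
    the result into M in place."""
    if chunk_size is None:
        m = layernorm(M.transpose(-2, -3))
        m = attn(m, M_mask.transpose(-1, -2))
        return M + m.transpose(-2, -3)

    for i in range(0, M.shape[2], chunk_size):
        m = M[:, :, i:i + chunk_size, :].transpose(-2, -3)
        m = attn(layernorm(m), M_mask[:, :, i:i + chunk_size].transpose(-1, -2))
        M[:, :, i:i + chunk_size, :] += m.transpose(-2, -3)
    return M


layernorm = nn.LayerNorm(16)
attn = lambda m, mask: m * mask.unsqueeze(-1)  # stand-in for global attention
M = torch.randn(1, 5, 20, 16)
M_mask = torch.ones(1, 5, 20)
assert torch.allclose(column_update(M.clone(), M_mask, layernorm, attn),
                      column_update(M.clone(), M_mask, layernorm, attn, chunk_size=8),
                      atol=1e-5)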
@@ -1109,16 +1131,16 @@ class RecyclingEmbedder(nn.Module):
         # [*, N, N, no_bins]
         d = ((d > squared_bins) * (d < upper)).type(x.dtype)
 
-        # [*, N, N, C_z]
-        para_dim = d.shape[1]
         if CHUNK_SIZE == None:
-            chunk_size = para_dim
+            d = self.linear(d)
+            z = d + self.layer_norm_z(z)
         else:
             chunk_size = CHUNK_SIZE * 48
-
-        for i in range(0, para_dim, chunk_size):
-            di = self.linear(d[i:i + chunk_size, :, :])
-            z[i:i + chunk_size, :, :] = di + self.layer_norm_z(z[i:i + chunk_size, :, :])
+            para_dim = d.shape[1]
+            for i in range(0, para_dim, chunk_size):
+                di = self.linear(d[i:i + chunk_size, :, :])
+                z[i:i + chunk_size, :, :] = di + self.layer_norm_z(z[i:i + chunk_size, :, :])
 
         return m_update, z
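Here the full pair representation z is updated slice by slice, so the only temporary is one [chunk, N, C_z] block instead of a second full copy of z. A hedged sketch of that in-place pattern (hypothetical recycle_pair_update, shapes assumed from the surrounding code):

import torch
import torch.nn as nn


def recycle_pair_update(d, z, linear, layer_norm_z, chunk_size=None):
    """Sketch: d [N, N, no_bins] one-hot distance bins, z [N, N, C_z] pair repr.
    The chunked path overwrites z slice by slice, so only one
    [chunk, N, C_z] temporary exists at a time."""
    if chunk_size is None:
        return linear(d) + layer_norm_z(z)

    for i in range(0, d.shape[0], chunk_size):
        di = linear(d[i:i + chunk_size])
        z[i:i + chunk_size] = di + layer_norm_z(z[i:i + chunk_size])
    return z


N, no_bins, c_z = 64, 15, 8
linear = nn.Linear(no_bins, c_z)
layer_norm_z = nn.LayerNorm(c_z)
d = torch.randn(N, N, no_bins)
z = torch.randn(N, N, c_z)
full = recycle_pair_update(d, z.clone(), linear, layer_norm_z)
chunked = recycle_pair_update(d, z.clone(), linear, layer_norm_z, chunk_size=16)
assert torch.allclose(full, chunked, atol=1e-5)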
@@ -1152,44 +1174,66 @@ class GlobalAttention(nn.Module):
     def forward(self, m, mask):
-        para_dim = m.shape[1]
-        chunk_size = CHUNK_SIZE
         if CHUNK_SIZE == None:
-            chunk_size = para_dim
-
-        output = []
-        for ax in range(0, para_dim, chunk_size):
-            m_part = m[:, ax:ax + chunk_size, :, :]
-            mask_part = mask[:, ax:ax + chunk_size, :]
-
-            q = torch.sum(m_part * mask_part.unsqueeze(-1), dim=-2) / (
-                torch.sum(mask_part, dim=-1)[..., None] + self.eps)
+            q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / (
+                torch.sum(mask, dim=-1)[..., None] + self.eps)
 
             q = q * self.scaling
             q = self.to_q(q)
             q = q.view(q.shape[:-1] + (self.n_head, -1))
 
-            k, v = self.to_kv(m_part).chunk(2, dim=-1)
+            k, v = self.to_kv(m).chunk(2, dim=-1)
 
             logits = torch.matmul(q, k.transpose(-1, -2))
 
-            weights = fused_softmax(logits, mask_part)
+            weights = fused_softmax(logits, mask)
 
             weighted_avg = torch.matmul(weights, v)
             weighted_avg = rearrange(weighted_avg, "b1 b2 h d -> b1 b2 (h d)")
 
-            gate_values = self.gating_linear(m_part)
+            gate_values = self.gating_linear(m)
 
             weighted_avg = bias_sigmod_ele(
                 gate_values, self.gating_bias, weighted_avg.unsqueeze(-2)
             )
 
-            output.append(self.o_linear(weighted_avg))
-
-        m = torch.cat(output, dim=1)
+            m = self.o_linear(weighted_avg)
+        else:
+            para_dim = m.shape[1]
+            chunk_size = CHUNK_SIZE
+            output = []
+            for ax in range(0, para_dim, chunk_size):
+                m_part = m[:, ax:ax + chunk_size, :, :]
+                mask_part = mask[:, ax:ax + chunk_size, :]
+
+                q = torch.sum(m_part * mask_part.unsqueeze(-1), dim=-2) / (
+                    torch.sum(mask_part, dim=-1)[..., None] + self.eps)
+
+                q = q * self.scaling
+                q = self.to_q(q)
+                q = q.view(q.shape[:-1] + (self.n_head, -1))
+
+                k, v = self.to_kv(m_part).chunk(2, dim=-1)
+
+                logits = torch.matmul(q, k.transpose(-1, -2))
+
+                weights = fused_softmax(logits, mask_part)
+
+                weighted_avg = torch.matmul(weights, v)
+                weighted_avg = rearrange(weighted_avg, "b1 b2 h d -> b1 b2 (h d)")
+
+                gate_values = self.gating_linear(m_part)
+
+                weighted_avg = bias_sigmod_ele(
+                    gate_values, self.gating_bias, weighted_avg.unsqueeze(-2)
+                )
+
+                output.append(self.o_linear(weighted_avg))
+
+            m = torch.cat(output, dim=1)
 
         return m
 
 
 class InputEmbedder(nn.Module):
     """
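GlobalAttention differs from SelfAttention in that each row gets a single query, a mask-weighted mean over its elements; since that pooling, the keys/values and the gate all read from the same row slice, chunking dim 1 stays exact. A sketch of just the pooled-query step in both paths (hypothetical pooled_query helper):

import torch


def pooled_query(m, mask, eps=1e-10, chunk_size=None):
    """Sketch: m [b, r, n, d], mask [b, r, n]. One query per row, taken as a
    mask-weighted mean over n; the chunked path pools each row slice alone."""
    if chunk_size is None:
        return torch.sum(m * mask.unsqueeze(-1), dim=-2) / (
            torch.sum(mask, dim=-1)[..., None] + eps)

    parts = []
    for ax in range(0, m.shape[1], chunk_size):
        m_part = m[:, ax:ax + chunk_size]
        mask_part = mask[:, ax:ax + chunk_size]
        parts.append(torch.sum(m_part * mask_part.unsqueeze(-1), dim=-2) / (
            torch.sum(mask_part, dim=-1)[..., None] + eps))
    return torch.cat(parts, dim=1)


m = torch.randn(1, 9, 6, 16)
mask = torch.ones(1, 9, 6)
mask[..., -1] = 0
assert torch.allclose(pooled_query(m, mask), pooled_query(m, mask, chunk_size=4))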
fastfold/model/fastnn/template.py
@@ -387,9 +387,11 @@ class TemplatePairStack(nn.Module):
             args=(t,),
             blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
         )
 
-        for i in range(0, t.shape[0]):
-            t[i] = self.layer_norm(t[i])
+        if not self.training:
+            for i in range(0, t.shape[0]):
+                t[i] = self.layer_norm(t[i])
+        else:
+            t = self.layer_norm(t[i])
 
         return t
 
     def inplace(
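template.py applies the same idea to the final layer norm of the template stack: at inference the templates are normalized one at a time, in place, instead of allocating a normalized copy of the whole [n_templ, N, N, c_t] tensor at once. A sketch that assumes the training path simply normalizes the whole stack (hypothetical finalize_templates; the diff's else branch references t[i], which this sketch does not reproduce):

import torch
import torch.nn as nn


def finalize_templates(t, layer_norm, training):
    """Sketch: t [n_templ, N, N, c_t]. At inference, normalize template by
    template in place so only one normalized copy exists at a time."""
    if not training:
        for i in range(0, t.shape[0]):
            t[i] = layer_norm(t[i])
        return t
    return layer_norm(t)  # assumed training path: keep one differentiable op


layer_norm = nn.LayerNorm(8)
t = torch.randn(4, 16, 16, 8)
whole = layer_norm(t.clone())
looped = finalize_templates(t.clone(), layer_norm, training=False)
assert torch.allclose(whole, looped, atol=1e-5)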