OpenDAS / FastFold · Commit 665e6c97

add chunk for self_att and opm (#38)

Authored Jul 29, 2022 by shenggan; committed by Shenggan, Jul 29, 2022
Parent: 48ae1b08

Showing 4 changed files with 62 additions and 21 deletions (+62 -21)
fastfold/config.py                  +1  -1
fastfold/model/fastnn/__init__.py   +2  -2
fastfold/model/fastnn/ops.py        +56 -18
inference.py                        +3  -0
fastfold/config.py

@@ -96,7 +96,7 @@ c_t = mlc.FieldReference(64, field_type=int)
 c_e = mlc.FieldReference(64, field_type=int)
 c_s = mlc.FieldReference(384, field_type=int)
 blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
-chunk_size = mlc.FieldReference(4, field_type=int)
+chunk_size = mlc.FieldReference(None, field_type=int)
 aux_distogram_bins = mlc.FieldReference(64, field_type=int)
 tm_enabled = mlc.FieldReference(False, field_type=bool)
 eps = mlc.FieldReference(1e-8, field_type=float)
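With this change the default chunk_size resolves to None, so chunking becomes opt-in instead of always running with a window of 4. A minimal sketch of how an ml_collections FieldReference default can be overridden per run; the globals.chunk_size path mirrors the inference change below, and the value 64 is purely illustrative:

import ml_collections as mlc

# A FieldReference is a single value shared by every config entry that points at it;
# None here means "do not chunk".
chunk_size = mlc.FieldReference(None, field_type=int)
config = mlc.ConfigDict({'globals': {'chunk_size': chunk_size}})

# Override for a memory-constrained run (illustrative value, not a FastFold default).
config.globals.chunk_size = 64
print(config.globals.chunk_size)  # 64; leaving it at None keeps full-width computation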
fastfold/model/fastnn/__init__.py

 from .msa import MSAStack
-from .ops import OutProductMean
+from .ops import OutProductMean, set_chunk_size
 from .triangle import PairStack
 from .evoformer import Evoformer

-__all__ = ['MSAStack', 'OutProductMean', 'PairStack', 'Evoformer']
+__all__ = ['MSAStack', 'OutProductMean', 'PairStack', 'Evoformer', 'set_chunk_size']
fastfold/model/fastnn/ops.py

@@ -11,6 +11,14 @@ from fastfold.model.fastnn.kernel import bias_sigmod_ele
 from fastfold.distributed import gather, scatter
 from fastfold.distributed.comm_async import gather_async, gather_async_opp
 
+CHUNK_SIZE = None
+
+
+def set_chunk_size(chunk_size):
+    global CHUNK_SIZE
+    CHUNK_SIZE = chunk_size
+
+
 class DropoutRowwise(nn.Module):
 
     def __init__(self, p):
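The module-level CHUNK_SIZE is read on every forward pass, so the switch can be flipped at runtime without reconstructing any modules. A small standalone sketch of the fallback logic used in the two hunks below (chunk_bounds is an illustrative helper, not part of FastFold):

CHUNK_SIZE = None

def set_chunk_size(chunk_size):
    global CHUNK_SIZE
    CHUNK_SIZE = chunk_size

def chunk_bounds(para_dim):
    # Mirrors OutProductMean/SelfAttention: with CHUNK_SIZE unset the loop
    # degenerates to a single full-width slice, i.e. the original behaviour.
    chunk_size = CHUNK_SIZE if CHUNK_SIZE is not None else para_dim
    return [(ax, min(ax + chunk_size, para_dim)) for ax in range(0, para_dim, chunk_size)]

print(chunk_bounds(10))  # [(0, 10)]  -> no chunking
set_chunk_size(4)
print(chunk_bounds(10))  # [(0, 4), (4, 8), (8, 10)]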
@@ -81,9 +89,23 @@ class OutProductMean(nn.Module):
         right_act_all = gather_async_opp(right_act_all, work, dim=2)
         right_act_all = M_mask * right_act_all
 
-        O = torch.einsum('bsid,bsje->bijde', left_act, right_act_all)
-        O = rearrange(O, 'b i j d e -> b i j (d e)')
-        Z = self.o_linear(O)
+        para_dim = left_act.shape[2]
+        chunk_size = CHUNK_SIZE
+        if CHUNK_SIZE == None:
+            chunk_size = para_dim
+
+        out = []
+        for ax in range(0, para_dim, chunk_size):
+            left_act_part = left_act[:, :, ax:ax + chunk_size, :]
+            O = torch.einsum('bsid,bsje->bijde', left_act_part, right_act_all)
+            O = rearrange(O, 'b i j d e -> b i j (d e)')
+            out.append(self.o_linear(O))
+
+        Z = torch.cat(out, dim=1)
 
         Z /= (1e-3 + norm)
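Chunking is valid here because output row i of the outer-product mean depends only on left_act[:, :, i, :]: slicing that axis and concatenating the results along dim=1 reproduces the full einsum, while the peak size of the intermediate O tensor shrinks by roughly para_dim / chunk_size. A quick equivalence check with a plain nn.Linear standing in for self.o_linear and arbitrary small shapes:

import torch
import torch.nn as nn
from einops import rearrange

b, s, n, d = 1, 6, 10, 4
left_act = torch.randn(b, s, n, d)
right_act_all = torch.randn(b, s, n, d)
o_linear = nn.Linear(d * d, 8)

# Full-width computation.
O = torch.einsum('bsid,bsje->bijde', left_act, right_act_all)
Z_full = o_linear(rearrange(O, 'b i j d e -> b i j (d e)'))

# Chunked computation over the i axis, as in the hunk above.
chunk_size, out = 3, []
for ax in range(0, n, chunk_size):
    O = torch.einsum('bsid,bsje->bijde', left_act[:, :, ax:ax + chunk_size, :], right_act_all)
    out.append(o_linear(rearrange(O, 'b i j d e -> b i j (d e)')))
Z_chunked = torch.cat(out, dim=1)

print(torch.allclose(Z_full, Z_chunked, atol=1e-5))  # True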
@@ -157,27 +179,43 @@ class SelfAttention(nn.Module):
         :param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv]
         """
 
-        qkv = self.to_qkv(in_data).chunk(3, dim=-1)
-        q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)
-
-        q = q * self.scaling
-
-        logits = torch.matmul(q, k.transpose(-1, -2))
+        para_dim = in_data.shape[1]
+        chunk_size = CHUNK_SIZE
+        if CHUNK_SIZE == None:
+            chunk_size = para_dim
 
         if nonbatched_bias is not None:
             # logits += nonbatched_bias.unsqueeze(1)
             bias = gather_async_opp(*nonbatched_bias, dim=1)
             bias = rearrange(bias, 'b q k h -> b h q k')
-            weights = mask_bias_softmax(logits, mask, bias.unsqueeze(1))
-        else:
-            weights = mask_softmax(logits, mask)
 
-        weighted_avg = torch.matmul(weights, v)
-        weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
+        output = []
+        for ax in range(0, para_dim, chunk_size):
+            in_data_part = in_data[:, ax:ax + chunk_size, :, :]
+            mask_part = mask[:, ax:ax + chunk_size, :]
+
+            qkv = self.to_qkv(in_data_part).chunk(3, dim=-1)
+            q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)
+
+            q = q * self.scaling
+
+            logits = torch.matmul(q, k.transpose(-1, -2))
 
-        if self.gating:
-            gate_values = self.gating_linear(in_data)
-            weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)
+            if nonbatched_bias is not None:
+                weights = mask_bias_softmax(logits, mask_part, bias.unsqueeze(1))
+            else:
+                weights = mask_softmax(logits, mask)
+
+            weighted_avg = torch.matmul(weights, v)
+            weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
+
+            if self.gating:
+                gate_values = self.gating_linear(in_data_part)
+                weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)
+
+            output.append(self.o_linear(weighted_avg))
 
-        output = self.o_linear(weighted_avg)
+        output = torch.cat(output, dim=1)
 
         return output
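In SelfAttention the dominant intermediate is the logits tensor of shape [b1, b2, n_head, n, n], and each slice along the b2 axis attends independently, so chunking that axis bounds the peak at roughly chunk_size / b2 of the unchunked cost, at the price of more, smaller kernel launches. Rough back-of-the-envelope arithmetic (fp32, purely illustrative sizes, not a claim about any particular AlphaFold stage):

def logits_bytes(b1, b2, n_head, n, bytes_per_el=4):
    # Memory held by the attention logits [b1, b2, n_head, n, n].
    return b1 * b2 * n_head * n * n * bytes_per_el

full = logits_bytes(1, 2048, 8, 2048)      # every row at once
chunked = logits_bytes(1, 64, 8, 2048)     # with set_chunk_size(64)
print(f'{full / 2**30:.0f} GiB vs {chunked / 2**30:.0f} GiB per forward pass')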
inference.py

@@ -29,6 +29,7 @@ import fastfold
 import fastfold.relax.relax as relax
 from fastfold.common import protein, residue_constants
 from fastfold.config import model_config
+from fastfold.model.fastnn import set_chunk_size
 from fastfold.data import data_pipeline, feature_pipeline, templates
 from fastfold.utils import inject_fastnn
 from fastfold.utils.import_weights import import_jax_weights_

@@ -89,6 +90,8 @@ def inference_model(rank, world_size, result_q, batch, args):
     model = model.eval()
     model = model.cuda()
 
+    set_chunk_size(model.globals.chunk_size)
+
     with torch.no_grad():
         batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
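The same switch is available outside inference.py via the package-level re-export added above; a hedged usage sketch (the value 64 is illustrative, and passing None restores unchunked execution, matching the CHUNK_SIZE = None default):

import torch
from fastfold.model.fastnn import set_chunk_size

set_chunk_size(64)    # chunked: lower peak memory, more kernel launches (illustrative value)
with torch.no_grad():
    # out = model(batch)   # model forward pass, as in inference_model above
    pass
set_chunk_size(None)  # back to full-width computation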