OpenDAS / FastFold · Commits · b3af1957

Commit b3af1957 (unverified)
Authored Sep 30, 2022 by oahzxl; committed by GitHub on Sep 30, 2022
Parent: b3d4fcca

Multimer supports chunk and inplace (#77)

Showing 9 changed files with 88 additions and 48 deletions (+88 -48)
environment.yml                           +2   -2
fastfold/model/fastnn/blocks.py           +22  -12
fastfold/model/fastnn/msa.py              +1   -1
fastfold/model/fastnn/ops.py              +14  -11
fastfold/model/hub/alphafold.py           +7   -3
fastfold/model/nn/embedders_multimer.py   +28  -18
fastfold/model/nn/evoformer.py            +8   -0
fastfold/utils/inject_fastnn.py           +5   -1
inference.py                              +1   -0
environment.yml  +2 -2

@@ -13,12 +13,12 @@ dependencies:
   - requests==2.26.0
   - scipy==1.7.1
   - tqdm==4.62.2
-  - typing-extensions==3.10.0.2
+  - typing-extensions==4.3.0
   - einops
-  - colossalai
   - ray==2.0.0
   - pyarrow
   - pandas
+  - --find-links https://release.colossalai.org colossalai
   - --find-links https://download.pytorch.org/whl/cu113/torch_stable.html torch==1.11.1+cu113
   - --find-links https://download.pytorch.org/whl/cu113/torch_stable.html torchaudio==0.11.1+cu113
   - --find-links https://download.pytorch.org/whl/cu113/torch_stable.html torchvision==0.13.1
...
fastfold/model/fastnn/blocks.py  +22 -12

@@ -137,12 +137,17 @@ class EvoformerBlock(nn.Module):
             z = self.pair_stack.inplace(z, pair_mask)
             m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
         else:
-            z = self.communication(m, msa_mask, z)
-            z_ori = z
-            m, work = All_to_All_Async.apply(m, 1, 2)
-            z = self.pair_stack(z, pair_mask)
-            m = All_to_All_Async_Opp.apply(m, work, 1, 2)
-            m = self.msa_stack(m, z_ori, msa_mask)
+            # z = self.communication.inplace(m[0], msa_mask, z)
+            # z_ori = z[0].clone()
+            # m[0], work = All_to_All_Async.apply(m[0], 1, 2)
+            # z = self.pair_stack.inplace(z, pair_mask)
+            # m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
+            # m[0] = self.msa_stack(m[0], z_ori, msa_mask)
+            z = self.communication.inplace(m[0], msa_mask, z)
+            m[0], work = All_to_All_Async.apply(m[0], 1, 2)
+            m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
+            m[0] = self.msa_stack(m[0], z[0], msa_mask)
+            z = self.pair_stack.inplace(z, pair_mask)

         if self.last_block:
             m[0] = m[0].squeeze(0)
@@ -288,12 +293,17 @@ class ExtraMSABlock(nn.Module):
             z = self.pair_stack.inplace(z, pair_mask)
             m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
         else:
-            z = self.communication(m, msa_mask, z)
-            z_ori = z
-            m, work = All_to_All_Async.apply(m, 1, 2)
-            z = self.pair_stack(z, pair_mask)
-            m = All_to_All_Async_Opp.apply(m, work, 1, 2)
-            m = self.msa_stack(m, z_ori, msa_mask)
+            # z = self.communication.inplace(m[0], msa_mask, z)
+            # z_ori = [z[0].clone()]
+            # m[0], work = All_to_All_Async.apply(m[0], 1, 2)
+            # z = self.pair_stack.inplace(z, pair_mask)
+            # m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
+            # m = self.msa_stack.inplace(m, z_ori, msa_mask)
+            z = self.communication.inplace(m[0], msa_mask, z)
+            m[0], work = All_to_All_Async.apply(m[0], 1, 2)
+            m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
+            m = self.msa_stack.inplace(m, z, msa_mask)
+            z = self.pair_stack.inplace(z, pair_mask)

         if self.last_block:
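For orientation on the pattern above: in the non-multimer branch, All_to_All_Async.apply launches the sequence-parallel exchange of m and returns a communication handle, the pair stack runs while the transfer is in flight, and All_to_All_Async_Opp.apply waits on that handle just before m is consumed. A minimal sketch of this launch/compute/wait split using plain torch.distributed — an illustration of the idea, not FastFold's actual autograd functions:

import torch
import torch.distributed as dist

def all_to_all_launch(x):
    # Start a non-blocking all-to-all; return the output buffer plus the
    # communication handle without waiting for the transfer to finish.
    out = torch.empty_like(x)
    work = dist.all_to_all_single(out, x.contiguous(), async_op=True)
    return out, work

def all_to_all_wait(out, work):
    # Block only where the communicated tensor is actually needed, so
    # independent compute (e.g. the pair stack) overlaps the transfer.
    work.wait()
    return out

Note that in the new multimer inplace branch the wait follows the launch immediately: the multimer block ordering updates z from m[0] first and consumes m[0] right after the exchange, so there is seemingly no independent compute left to hide the transfer behind.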
fastfold/model/fastnn/msa.py  +1 -1

@@ -172,7 +172,7 @@ class ExtraMSAStack(nn.Module):
     def inplace(self, node, pair, node_mask):
         node_mask_row = scatter(node_mask, dim=1)
-        node = self.MSARowAttentionWithPairBias.inplace(node, pair, node_mask_row)
+        node = self.MSARowAttentionWithPairBias.inplace(node, pair[0], node_mask_row)
         node[0] = row_to_col(node[0])
         node_mask_col = scatter(node_mask, dim=2)
fastfold/model/fastnn/ops.py  +14 -11

@@ -675,17 +675,18 @@ class ChunkTriangleAttentionStartingNode(nn.Module):
     def inplace(self, Z_raw, Z_mask):
         if CHUNK_SIZE == None:
-            Z = self.layernorm1(Z_raw)
+            Z = self.layernorm1(Z_raw[0])
             b = self.linear_b(Z)
             b, work = gather_async(b, dim=1)
             Z = self.attention(Z, Z_mask, (b, work))
             dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
-            return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop, training=self.training)
+            Z_raw[0] = bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw[0], prob=self.p_drop, training=self.training)
+            return Z_raw

         chunk_size = CHUNK_SIZE
         para_dim = Z_raw[0].shape[1]
@@ -795,7 +796,7 @@ class ChunkMSARowAttentionWithPairBias(nn.Module):
         if CHUNK_SIZE == None:
             ## Input projections
-            M = self.layernormM(M_raw)
+            M = self.layernormM(M_raw[0])
             Z = self.layernormZ(Z)
             b = F.linear(Z, self.linear_b_weights)
             b, work = gather_async(b, dim=1)
@@ -803,15 +804,16 @@ class ChunkMSARowAttentionWithPairBias(nn.Module):
             # padding_bias = (1e9 * (M_mask - 1.))[:, :, None, None, :]
             M = self.attention(M, M_mask, (b, work))
             dropout_mask = torch.ones_like(M[:, 0:1, :, :], device=M.device, dtype=M.dtype)
-            return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop, training=self.training)
+            M_raw[0] = bias_dropout_add(M, self.out_bias, dropout_mask, M_raw[0], prob=self.p_drop, training=self.training)
+            return M_raw

         chunk_size = CHUNK_SIZE
-        para_dim_z = Z[0].shape[1]
+        para_dim_z = Z.shape[1]
         para_dim_m = M_raw[0].shape[1]
         # z is big, but b is small. So we compute z in chunk to get b, and recompute z in chunk later instead of storing it
-        b = torch.empty((Z[0].shape[0], Z[0].shape[1], Z[0].shape[2], self.n_head), device=Z[0].device, dtype=Z[0].dtype)
+        b = torch.empty((Z.shape[0], Z.shape[1], Z.shape[2], self.n_head), device=Z.device, dtype=Z.dtype)
         for i in range(0, para_dim_z, chunk_size):
-            z = self.layernormZ(Z[0][:, i:i + chunk_size, :, :])
+            z = self.layernormZ(Z[:, i:i + chunk_size, :, :])
             b[:, i:i + chunk_size, :, :] = F.linear(z, self.linear_b_weights)
         b, work = gather_async(b, dim=1)
         b = gather_async_opp(b, work, dim=1)
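The in-line comment in this hunk ("z is big, but b is small...") captures the memory trade the chunked path makes: only chunk_size rows of the layer-normalized pair representation exist at any moment, while the much smaller per-head bias b is kept whole, and the normalization is simply recomputed later when needed. A self-contained sketch of that pattern with hypothetical shapes, not the class above verbatim:

import torch
import torch.nn.functional as F

def chunked_pair_bias(Z, layernorm, weight, chunk_size=64):
    # Z:      [B, N, N, c_z]   large pair representation
    # weight: [n_head, c_z]    projection down to a per-head attention bias
    # Returns [B, N, N, n_head], small because n_head << c_z.
    B, N, _, _ = Z.shape
    n_head = weight.shape[0]
    b = torch.empty((B, N, N, n_head), device=Z.device, dtype=Z.dtype)
    for i in range(0, N, chunk_size):
        # Normalize one slab at a time; the slab is freed each iteration
        # instead of a full normalized copy of Z being stored.
        z = layernorm(Z[:, i:i + chunk_size, :, :])
        b[:, i:i + chunk_size, :, :] = F.linear(z, weight)
    return b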
@@ -910,7 +912,7 @@ class ChunkTriangleAttentionEndingNode(nn.Module):
     def inplace(self, Z_raw, Z_mask):
         if CHUNK_SIZE == None:
-            Z = Z_raw.transpose(-2, -3)
+            Z = Z_raw[0].transpose(-2, -3)
             Z_mask = Z_mask.transpose(-1, -2)
             Z = self.layernorm1(Z)
@@ -919,12 +921,13 @@ class ChunkTriangleAttentionEndingNode(nn.Module):
             Z = self.attention(Z, Z_mask, (b, work))
             Z = Z.transpose(-2, -3)
             dropout_mask = torch.ones_like(Z[:, :, 0:1, :], device=Z.device, dtype=Z.dtype)
-            return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop, training=self.training)
+            Z_raw[0] = bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw[0], prob=self.p_drop, training=self.training)
+            return Z_raw

         para_dim = Z_raw[0].shape[2]
         chunk_size = CHUNK_SIZE
fastfold/model/hub/alphafold.py  +7 -3

@@ -88,9 +88,11 @@ class AlphaFold(nn.Module):
             **extra_msa_config["extra_msa_embedder"],
         )
         self.extra_msa_stack = ExtraMSAStack(
+            is_multimer=self.globals.is_multimer,
             **extra_msa_config["extra_msa_stack"],
         )
         self.evoformer = EvoformerStack(
+            is_multimer=self.globals.is_multimer,
             **config["evoformer_stack"],
         )
         self.structure_module = StructureModule(
@@ -269,6 +271,7 @@ class AlphaFold(nn.Module):
                     no_batch_dims,
                     chunk_size=self.globals.chunk_size,
                     multichain_mask_2d=multichain_mask_2d,
+                    inplace=self.globals.inplace
                 )
                 feats["template_torsion_angles_mask"] = (
                     template_embeds["template_mask"]
@@ -302,12 +305,13 @@ class AlphaFold(nn.Module):
                     [feats["msa_mask"], torsion_angles_mask[..., 2]], dim=-2
                 )
+                del torsion_angles_mask
             else:
                 msa_mask = torch.cat(
                     [feats["msa_mask"], template_embeds["template_mask"]], dim=-2,
                 )
-            del template_feats, template_embeds, torsion_angles_mask
+            del template_feats, template_embeds

         # Embed extra MSA features + merge with pairwise embeddings
         if self.config.extra_msa.enabled:
@@ -321,7 +325,7 @@ class AlphaFold(nn.Module):
             extra_msa_feat = self.extra_msa_embedder(extra_msa_feat)

             # [*, N, N, C_z]
-            if not self.globals.inplace or self.globals.is_multimer:
+            if not self.globals.inplace:
                 z = self.extra_msa_stack(
                     extra_msa_feat,
                     z,
@@ -347,7 +351,7 @@ class AlphaFold(nn.Module):
         # m: [*, S, N, C_m]
         # z: [*, N, N, C_z]
         # s: [*, N, C_s]
-        if not self.globals.inplace or self.globals.is_multimer:
+        if not self.globals.inplace:
             m, z, s = self.evoformer(
                 m,
                 z,
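A reading note on the convention these branches select between: when globals.inplace is set, m and z travel through the stacks as one-element lists rather than bare tensors, so a callee can overwrite the single shared slot and the previous activation loses its last reference instead of surviving across the call boundary. A toy illustration of why the wrapper matters — illustrative only, not FastFold code:

import torch

def update_bare(t):
    # Rebinding a bare tensor is invisible to the caller, whose reference
    # keeps the old activation alive alongside the new one.
    return t * 2.0

def update_boxed(box):
    # Overwriting the shared slot means every holder of `box` sees the new
    # tensor, and the old one can be freed as soon as this line runs.
    box[0] = box[0] * 2.0
    return box

m = [torch.ones(8, 8)]  # wrap once before the first block
m = update_boxed(m)
result = m[0]           # unwrap after the last block (cf. m[0].squeeze(0))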
fastfold/model/nn/embedders_multimer.py  +28 -18

@@ -254,18 +254,18 @@ class TemplateSingleEmbedderMultimer(nn.Module):
         template_mask = template_chi_mask[..., 0]

-        template_activations = self.template_single_embedder(
+        template_features = self.template_single_embedder(
             template_features
         )
-        template_activations = torch.nn.functional.relu(
-            template_activations
+        template_features = torch.nn.functional.relu(
+            template_features
         )
-        template_activations = self.template_projector(
-            template_activations,
+        template_features = self.template_projector(
+            template_features,
         )
         out["template_single_embedding"] = (
-            template_activations
+            template_features
         )
         out["template_mask"] = template_mask
@@ -296,6 +296,7 @@ class TemplateEmbedderMultimer(nn.Module):
         templ_dim,
         chunk_size,
         multichain_mask_2d,
+        inplace
     ):
         template_embeds = []
         n_templ = batch["template_aatype"].shape[templ_dim]
@@ -307,7 +308,6 @@ class TemplateEmbedderMultimer(nn.Module):
             )
             single_template_embeds = {}
-            act = 0.
             template_positions, pseudo_beta_mask = (
                 single_template_feats["template_pseudo_beta"],
@@ -361,17 +361,27 @@ class TemplateEmbedderMultimer(nn.Module):
                 template_embeds,
             )

-        # [*, S_t, N, N, C_z]
-        t = self.template_pair_stack(
-            template_embeds["template_pair_embedding"],
-            padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
-            chunk_size=chunk_size,
-            _mask_trans=False,
-        )
+        if not inplace:
+            # [*, S_t, N, N, C_z]
+            template_embeds["template_pair_embedding"] = self.template_pair_stack(
+                template_embeds["template_pair_embedding"],
+                padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
+                chunk_size=chunk_size,
+                _mask_trans=False,
+            )
+        else:
+            template_embeds["template_pair_embedding"] = [template_embeds["template_pair_embedding"]]
+            # [*, S_t, N, N, C_z]
+            template_embeds["template_pair_embedding"] = self.template_pair_stack.inplace(
+                template_embeds["template_pair_embedding"],
+                padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
+                chunk_size=chunk_size,
+                _mask_trans=False,
+            )[0].to(z.device)

         # [*, N, N, C_z]
-        t = torch.sum(t, dim=-4) / n_templ
-        t = torch.nn.functional.relu(t)
-        t = self.linear_t(t)
-        template_embeds["template_pair_embedding"] = t
+        template_embeds["template_pair_embedding"] = torch.sum(template_embeds["template_pair_embedding"], dim=-4) / n_templ
+        template_embeds["template_pair_embedding"] = torch.nn.functional.relu(template_embeds["template_pair_embedding"])
+        template_embeds["template_pair_embedding"] = self.linear_t(template_embeds["template_pair_embedding"])

         return template_embeds
fastfold/model/nn/evoformer.py  +8 -0

@@ -229,6 +229,7 @@ class EvoformerBlock(nn.Module):
         pair_dropout: float,
         inf: float,
         eps: float,
+        is_multimer: bool,
     ):
         super(EvoformerBlock, self).__init__()

@@ -268,6 +269,7 @@ class EvoformerBlock(nn.Module):
             c_z,
             c_hidden_opm,
         )
+        self.is_multimer = is_multimer

     def forward(self,
         m: torch.Tensor,

@@ -315,6 +317,7 @@ class ExtraMSABlock(nn.Module):
         inf: float,
         eps: float,
         ckpt: bool,
+        is_multimer: bool,
     ):
         super(ExtraMSABlock, self).__init__()

@@ -351,6 +354,7 @@ class ExtraMSABlock(nn.Module):
             inf=inf,
             eps=eps,
         )
+        self.is_multimer = is_multimer

     def forward(self,
         m: torch.Tensor,

@@ -415,6 +419,7 @@ class EvoformerStack(nn.Module):
         inf: float,
         eps: float,
         clear_cache_between_blocks: bool = False,
+        is_multimer: bool = False,
         **kwargs,
     ):
         """

@@ -474,6 +479,7 @@ class EvoformerStack(nn.Module):
                 pair_dropout=pair_dropout,
                 inf=inf,
                 eps=eps,
+                is_multimer=is_multimer,
             )
             self.blocks.append(block)

@@ -610,6 +616,7 @@ class ExtraMSAStack(nn.Module):
         eps: float,
         ckpt: bool,
         clear_cache_between_blocks: bool = False,
+        is_multimer: bool = False,
         **kwargs,
     ):
         super(ExtraMSAStack, self).__init__()

@@ -632,6 +639,7 @@ class ExtraMSAStack(nn.Module):
                 inf=inf,
                 eps=eps,
                 ckpt=ckpt,
+                is_multimer=is_multimer,
             )
             self.blocks.append(block)
fastfold/utils/inject_fastnn.py  +5 -1

@@ -228,10 +228,12 @@ def inject_evoformer(model):
     for block_id, ori_block in enumerate(model.evoformer.blocks):
         c_m = ori_block.msa_att_row.c_in
         c_z = ori_block.msa_att_row.c_z
+        is_multimer = ori_block.is_multimer
         fastfold_block = EvoformerBlock(c_m=c_m,
                                         c_z=c_z,
                                         first_block=(block_id == 0),
-                                        last_block=(block_id == len(model.evoformer.blocks) - 1))
+                                        last_block=(block_id == len(model.evoformer.blocks) - 1),
+                                        is_multimer=is_multimer,
+                                        )
         copy_evoformer_para(fastfold_block, ori_block)

@@ -249,11 +251,13 @@ def inject_extraMsaBlock(model):
     for block_id, ori_block in enumerate(model.extra_msa_stack.blocks):
         c_m = ori_block.msa_att_row.c_in
         c_z = ori_block.msa_att_row.c_z
+        is_multimer = ori_block.is_multimer
         new_model_block = ExtraMSABlock(c_m=c_m,
                                         c_z=c_z,
                                         first_block=(block_id == 0),
                                         last_block=(block_id == len(model.extra_msa_stack.blocks) - 1),
+                                        is_multimer=is_multimer
                                         )
         copy_extra_msa_para(new_model_block, ori_block)
inference.py  +1 -0

@@ -115,6 +115,7 @@ def inference_model(rank, world_size, result_q, batch, args):
     if args.chunk_size:
         config.globals.chunk_size = args.chunk_size
     config.globals.inplace = args.inplace
+    config.globals.is_multimer = args.model_preset == 'multimer'
     model = AlphaFold(config)
     import_jax_weights_(model, args.param_path, version=args.model_name)
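Taken together with the changes above, this is the switch that lets the multimer preset opt into the same memory savers as monomer runs. A short annotated recap of what each global controls, based only on how this diff uses them; the comments are interpretive, not from the source:

# chunk_size: when set, the fastnn Chunk* ops iterate over slabs of this
# many rows instead of materializing full intermediates.
config.globals.chunk_size = args.chunk_size

# inplace: route m/z through the one-element-list convention so blocks can
# update activations in place.
config.globals.inplace = args.inplace

# is_multimer: newly derived from the CLI preset; AlphaFold.__init__ forwards
# it into EvoformerStack/ExtraMSAStack, each block stores it as
# self.is_multimer, and inject_fastnn.py reads it back when swapping in the
# fastnn blocks.
config.globals.is_multimer = args.model_preset == 'multimer'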