Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastFold
Commits
b3af1957
"vscode:/vscode.git/clone" did not exist on "9bafef34bde99a6184c4d4c5af8bd434244a5ed9"
Unverified
Commit
b3af1957
authored
Sep 30, 2022
by
oahzxl
Committed by
GitHub
Sep 30, 2022
Browse files
Multimer supports chunk and inplace (#77)
parent
b3d4fcca
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
88 additions
and
48 deletions
+88
-48
environment.yml
environment.yml
+2
-2
fastfold/model/fastnn/blocks.py
fastfold/model/fastnn/blocks.py
+22
-12
fastfold/model/fastnn/msa.py
fastfold/model/fastnn/msa.py
+1
-1
fastfold/model/fastnn/ops.py
fastfold/model/fastnn/ops.py
+14
-11
fastfold/model/hub/alphafold.py
fastfold/model/hub/alphafold.py
+7
-3
fastfold/model/nn/embedders_multimer.py
fastfold/model/nn/embedders_multimer.py
+28
-18
fastfold/model/nn/evoformer.py
fastfold/model/nn/evoformer.py
+8
-0
fastfold/utils/inject_fastnn.py
fastfold/utils/inject_fastnn.py
+5
-1
inference.py
inference.py
+1
-0
No files found.
environment.yml
View file @
b3af1957
...
@@ -13,12 +13,12 @@ dependencies:
...
@@ -13,12 +13,12 @@ dependencies:
-
requests==2.26.0
-
requests==2.26.0
-
scipy==1.7.1
-
scipy==1.7.1
-
tqdm==4.62.2
-
tqdm==4.62.2
-
typing-extensions==
3.10.0.2
-
typing-extensions==
4.3.0
-
einops
-
einops
-
colossalai
-
ray==2.0.0
-
ray==2.0.0
-
pyarrow
-
pyarrow
-
pandas
-
pandas
-
--find-links https://release.colossalai.org colossalai
-
--find-links https://download.pytorch.org/whl/cu113/torch_stable.html torch==1.11.1+cu113
-
--find-links https://download.pytorch.org/whl/cu113/torch_stable.html torch==1.11.1+cu113
-
--find-links https://download.pytorch.org/whl/cu113/torch_stable.html torchaudio==0.11.1+cu113
-
--find-links https://download.pytorch.org/whl/cu113/torch_stable.html torchaudio==0.11.1+cu113
-
--find-links https://download.pytorch.org/whl/cu113/torch_stable.html torchvision==0.13.1
-
--find-links https://download.pytorch.org/whl/cu113/torch_stable.html torchvision==0.13.1
...
...
fastfold/model/fastnn/blocks.py
View file @
b3af1957
...
@@ -137,12 +137,17 @@ class EvoformerBlock(nn.Module):
...
@@ -137,12 +137,17 @@ class EvoformerBlock(nn.Module):
z
=
self
.
pair_stack
.
inplace
(
z
,
pair_mask
)
z
=
self
.
pair_stack
.
inplace
(
z
,
pair_mask
)
m
[
0
]
=
All_to_All_Async_Opp
.
apply
(
m
[
0
],
work
,
1
,
2
)
m
[
0
]
=
All_to_All_Async_Opp
.
apply
(
m
[
0
],
work
,
1
,
2
)
else
:
else
:
z
=
self
.
communication
(
m
,
msa_mask
,
z
)
# z = self.communication.inplace(m[0], msa_mask, z)
z_ori
=
z
# z_ori = z[0].clone()
m
,
work
=
All_to_All_Async
.
apply
(
m
,
1
,
2
)
# m[0], work = All_to_All_Async.apply(m[0], 1, 2)
z
=
self
.
pair_stack
(
z
,
pair_mask
)
# z = self.pair_stack.inplace(z, pair_mask)
m
=
All_to_All_Async_Opp
.
apply
(
m
,
work
,
1
,
2
)
# m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
m
=
self
.
msa_stack
(
m
,
z_ori
,
msa_mask
)
# m[0] = self.msa_stack(m[0], z_ori, msa_mask)
z
=
self
.
communication
.
inplace
(
m
[
0
],
msa_mask
,
z
)
m
[
0
],
work
=
All_to_All_Async
.
apply
(
m
[
0
],
1
,
2
)
m
[
0
]
=
All_to_All_Async_Opp
.
apply
(
m
[
0
],
work
,
1
,
2
)
m
[
0
]
=
self
.
msa_stack
(
m
[
0
],
z
[
0
],
msa_mask
)
z
=
self
.
pair_stack
.
inplace
(
z
,
pair_mask
)
if
self
.
last_block
:
if
self
.
last_block
:
m
[
0
]
=
m
[
0
].
squeeze
(
0
)
m
[
0
]
=
m
[
0
].
squeeze
(
0
)
...
@@ -288,12 +293,17 @@ class ExtraMSABlock(nn.Module):
...
@@ -288,12 +293,17 @@ class ExtraMSABlock(nn.Module):
z
=
self
.
pair_stack
.
inplace
(
z
,
pair_mask
)
z
=
self
.
pair_stack
.
inplace
(
z
,
pair_mask
)
m
[
0
]
=
All_to_All_Async_Opp
.
apply
(
m
[
0
],
work
,
1
,
2
)
m
[
0
]
=
All_to_All_Async_Opp
.
apply
(
m
[
0
],
work
,
1
,
2
)
else
:
else
:
z
=
self
.
communication
(
m
,
msa_mask
,
z
)
# z = self.communication.inplace(m[0], msa_mask, z)
z_ori
=
z
# z_ori = [z[0].clone()]
m
,
work
=
All_to_All_Async
.
apply
(
m
,
1
,
2
)
# m[0], work = All_to_All_Async.apply(m[0], 1, 2)
z
=
self
.
pair_stack
(
z
,
pair_mask
)
# z = self.pair_stack.inplace(z, pair_mask)
m
=
All_to_All_Async_Opp
.
apply
(
m
,
work
,
1
,
2
)
# m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
m
=
self
.
msa_stack
(
m
,
z_ori
,
msa_mask
)
# m = self.msa_stack.inplace(m, z_ori, msa_mask)
z
=
self
.
communication
.
inplace
(
m
[
0
],
msa_mask
,
z
)
m
[
0
],
work
=
All_to_All_Async
.
apply
(
m
[
0
],
1
,
2
)
m
[
0
]
=
All_to_All_Async_Opp
.
apply
(
m
[
0
],
work
,
1
,
2
)
m
=
self
.
msa_stack
.
inplace
(
m
,
z
,
msa_mask
)
z
=
self
.
pair_stack
.
inplace
(
z
,
pair_mask
)
if
self
.
last_block
:
if
self
.
last_block
:
...
...
fastfold/model/fastnn/msa.py
View file @
b3af1957
...
@@ -172,7 +172,7 @@ class ExtraMSAStack(nn.Module):
...
@@ -172,7 +172,7 @@ class ExtraMSAStack(nn.Module):
def
inplace
(
self
,
node
,
pair
,
node_mask
):
def
inplace
(
self
,
node
,
pair
,
node_mask
):
node_mask_row
=
scatter
(
node_mask
,
dim
=
1
)
node_mask_row
=
scatter
(
node_mask
,
dim
=
1
)
node
=
self
.
MSARowAttentionWithPairBias
.
inplace
(
node
,
pair
,
node_mask_row
)
node
=
self
.
MSARowAttentionWithPairBias
.
inplace
(
node
,
pair
[
0
]
,
node_mask_row
)
node
[
0
]
=
row_to_col
(
node
[
0
])
node
[
0
]
=
row_to_col
(
node
[
0
])
node_mask_col
=
scatter
(
node_mask
,
dim
=
2
)
node_mask_col
=
scatter
(
node_mask
,
dim
=
2
)
...
...
fastfold/model/fastnn/ops.py
View file @
b3af1957
...
@@ -675,17 +675,18 @@ class ChunkTriangleAttentionStartingNode(nn.Module):
...
@@ -675,17 +675,18 @@ class ChunkTriangleAttentionStartingNode(nn.Module):
def
inplace
(
self
,
Z_raw
,
Z_mask
):
def
inplace
(
self
,
Z_raw
,
Z_mask
):
if
CHUNK_SIZE
==
None
:
if
CHUNK_SIZE
==
None
:
Z
=
self
.
layernorm1
(
Z_raw
)
Z
=
self
.
layernorm1
(
Z_raw
[
0
]
)
b
=
self
.
linear_b
(
Z
)
b
=
self
.
linear_b
(
Z
)
b
,
work
=
gather_async
(
b
,
dim
=
1
)
b
,
work
=
gather_async
(
b
,
dim
=
1
)
Z
=
self
.
attention
(
Z
,
Z_mask
,
(
b
,
work
))
Z
=
self
.
attention
(
Z
,
Z_mask
,
(
b
,
work
))
dropout_mask
=
torch
.
ones_like
(
Z
[:,
0
:
1
,
:,
:],
device
=
Z
.
device
,
dtype
=
Z
.
dtype
)
dropout_mask
=
torch
.
ones_like
(
Z
[:,
0
:
1
,
:,
:],
device
=
Z
.
device
,
dtype
=
Z
.
dtype
)
return
bias_dropout_add
(
Z
,
Z_raw
[
0
]
=
bias_dropout_add
(
Z
,
self
.
out_bias
,
self
.
out_bias
,
dropout_mask
,
dropout_mask
,
Z_raw
,
Z_raw
[
0
]
,
prob
=
self
.
p_drop
,
prob
=
self
.
p_drop
,
training
=
self
.
training
)
training
=
self
.
training
)
return
Z_raw
chunk_size
=
CHUNK_SIZE
chunk_size
=
CHUNK_SIZE
para_dim
=
Z_raw
[
0
].
shape
[
1
]
para_dim
=
Z_raw
[
0
].
shape
[
1
]
...
@@ -795,7 +796,7 @@ class ChunkMSARowAttentionWithPairBias(nn.Module):
...
@@ -795,7 +796,7 @@ class ChunkMSARowAttentionWithPairBias(nn.Module):
if
CHUNK_SIZE
==
None
:
if
CHUNK_SIZE
==
None
:
## Input projections
## Input projections
M
=
self
.
layernormM
(
M_raw
)
M
=
self
.
layernormM
(
M_raw
[
0
]
)
Z
=
self
.
layernormZ
(
Z
)
Z
=
self
.
layernormZ
(
Z
)
b
=
F
.
linear
(
Z
,
self
.
linear_b_weights
)
b
=
F
.
linear
(
Z
,
self
.
linear_b_weights
)
b
,
work
=
gather_async
(
b
,
dim
=
1
)
b
,
work
=
gather_async
(
b
,
dim
=
1
)
...
@@ -803,15 +804,16 @@ class ChunkMSARowAttentionWithPairBias(nn.Module):
...
@@ -803,15 +804,16 @@ class ChunkMSARowAttentionWithPairBias(nn.Module):
# padding_bias = (1e9 * (M_mask - 1.))[:, :, None, None, :]
# padding_bias = (1e9 * (M_mask - 1.))[:, :, None, None, :]
M
=
self
.
attention
(
M
,
M_mask
,
(
b
,
work
))
M
=
self
.
attention
(
M
,
M_mask
,
(
b
,
work
))
dropout_mask
=
torch
.
ones_like
(
M
[:,
0
:
1
,
:,
:],
device
=
M
.
device
,
dtype
=
M
.
dtype
)
dropout_mask
=
torch
.
ones_like
(
M
[:,
0
:
1
,
:,
:],
device
=
M
.
device
,
dtype
=
M
.
dtype
)
return
bias_dropout_add
(
M
,
self
.
out_bias
,
dropout_mask
,
M_raw
,
prob
=
self
.
p_drop
,
training
=
self
.
training
)
M_raw
[
0
]
=
bias_dropout_add
(
M
,
self
.
out_bias
,
dropout_mask
,
M_raw
[
0
],
prob
=
self
.
p_drop
,
training
=
self
.
training
)
return
M_raw
chunk_size
=
CHUNK_SIZE
chunk_size
=
CHUNK_SIZE
para_dim_z
=
Z
[
0
]
.
shape
[
1
]
para_dim_z
=
Z
.
shape
[
1
]
para_dim_m
=
M_raw
[
0
].
shape
[
1
]
para_dim_m
=
M_raw
[
0
].
shape
[
1
]
# z is big, but b is small. So we compute z in chunk to get b, and recompute z in chunk later instead of storing it
# z is big, but b is small. So we compute z in chunk to get b, and recompute z in chunk later instead of storing it
b
=
torch
.
empty
((
Z
[
0
]
.
shape
[
0
],
Z
[
0
]
.
shape
[
1
],
Z
[
0
]
.
shape
[
2
],
self
.
n_head
),
device
=
Z
[
0
]
.
device
,
dtype
=
Z
[
0
]
.
dtype
)
b
=
torch
.
empty
((
Z
.
shape
[
0
],
Z
.
shape
[
1
],
Z
.
shape
[
2
],
self
.
n_head
),
device
=
Z
.
device
,
dtype
=
Z
.
dtype
)
for
i
in
range
(
0
,
para_dim_z
,
chunk_size
):
for
i
in
range
(
0
,
para_dim_z
,
chunk_size
):
z
=
self
.
layernormZ
(
Z
[
0
][
:,
i
:
i
+
chunk_size
,
:,
:])
z
=
self
.
layernormZ
(
Z
[:,
i
:
i
+
chunk_size
,
:,
:])
b
[:,
i
:
i
+
chunk_size
,
:,
:]
=
F
.
linear
(
z
,
self
.
linear_b_weights
)
b
[:,
i
:
i
+
chunk_size
,
:,
:]
=
F
.
linear
(
z
,
self
.
linear_b_weights
)
b
,
work
=
gather_async
(
b
,
dim
=
1
)
b
,
work
=
gather_async
(
b
,
dim
=
1
)
b
=
gather_async_opp
(
b
,
work
,
dim
=
1
)
b
=
gather_async_opp
(
b
,
work
,
dim
=
1
)
...
@@ -910,7 +912,7 @@ class ChunkTriangleAttentionEndingNode(nn.Module):
...
@@ -910,7 +912,7 @@ class ChunkTriangleAttentionEndingNode(nn.Module):
def
inplace
(
self
,
Z_raw
,
Z_mask
):
def
inplace
(
self
,
Z_raw
,
Z_mask
):
if
CHUNK_SIZE
==
None
:
if
CHUNK_SIZE
==
None
:
Z
=
Z_raw
.
transpose
(
-
2
,
-
3
)
Z
=
Z_raw
[
0
]
.
transpose
(
-
2
,
-
3
)
Z_mask
=
Z_mask
.
transpose
(
-
1
,
-
2
)
Z_mask
=
Z_mask
.
transpose
(
-
1
,
-
2
)
Z
=
self
.
layernorm1
(
Z
)
Z
=
self
.
layernorm1
(
Z
)
...
@@ -919,12 +921,13 @@ class ChunkTriangleAttentionEndingNode(nn.Module):
...
@@ -919,12 +921,13 @@ class ChunkTriangleAttentionEndingNode(nn.Module):
Z
=
self
.
attention
(
Z
,
Z_mask
,
(
b
,
work
))
Z
=
self
.
attention
(
Z
,
Z_mask
,
(
b
,
work
))
Z
=
Z
.
transpose
(
-
2
,
-
3
)
Z
=
Z
.
transpose
(
-
2
,
-
3
)
dropout_mask
=
torch
.
ones_like
(
Z
[:,
:,
0
:
1
,
:],
device
=
Z
.
device
,
dtype
=
Z
.
dtype
)
dropout_mask
=
torch
.
ones_like
(
Z
[:,
:,
0
:
1
,
:],
device
=
Z
.
device
,
dtype
=
Z
.
dtype
)
return
bias_dropout_add
(
Z
,
Z_raw
[
0
]
=
bias_dropout_add
(
Z
,
self
.
out_bias
,
self
.
out_bias
,
dropout_mask
,
dropout_mask
,
Z_raw
,
Z_raw
[
0
]
,
prob
=
self
.
p_drop
,
prob
=
self
.
p_drop
,
training
=
self
.
training
)
training
=
self
.
training
)
return
Z_raw
para_dim
=
Z_raw
[
0
].
shape
[
2
]
para_dim
=
Z_raw
[
0
].
shape
[
2
]
chunk_size
=
CHUNK_SIZE
chunk_size
=
CHUNK_SIZE
...
...
fastfold/model/hub/alphafold.py
View file @
b3af1957
...
@@ -88,9 +88,11 @@ class AlphaFold(nn.Module):
...
@@ -88,9 +88,11 @@ class AlphaFold(nn.Module):
**
extra_msa_config
[
"extra_msa_embedder"
],
**
extra_msa_config
[
"extra_msa_embedder"
],
)
)
self
.
extra_msa_stack
=
ExtraMSAStack
(
self
.
extra_msa_stack
=
ExtraMSAStack
(
is_multimer
=
self
.
globals
.
is_multimer
,
**
extra_msa_config
[
"extra_msa_stack"
],
**
extra_msa_config
[
"extra_msa_stack"
],
)
)
self
.
evoformer
=
EvoformerStack
(
self
.
evoformer
=
EvoformerStack
(
is_multimer
=
self
.
globals
.
is_multimer
,
**
config
[
"evoformer_stack"
],
**
config
[
"evoformer_stack"
],
)
)
self
.
structure_module
=
StructureModule
(
self
.
structure_module
=
StructureModule
(
...
@@ -269,6 +271,7 @@ class AlphaFold(nn.Module):
...
@@ -269,6 +271,7 @@ class AlphaFold(nn.Module):
no_batch_dims
,
no_batch_dims
,
chunk_size
=
self
.
globals
.
chunk_size
,
chunk_size
=
self
.
globals
.
chunk_size
,
multichain_mask_2d
=
multichain_mask_2d
,
multichain_mask_2d
=
multichain_mask_2d
,
inplace
=
self
.
globals
.
inplace
)
)
feats
[
"template_torsion_angles_mask"
]
=
(
feats
[
"template_torsion_angles_mask"
]
=
(
template_embeds
[
"template_mask"
]
template_embeds
[
"template_mask"
]
...
@@ -302,12 +305,13 @@ class AlphaFold(nn.Module):
...
@@ -302,12 +305,13 @@ class AlphaFold(nn.Module):
[
feats
[
"msa_mask"
],
torsion_angles_mask
[...,
2
]],
[
feats
[
"msa_mask"
],
torsion_angles_mask
[...,
2
]],
dim
=-
2
dim
=-
2
)
)
del
torsion_angles_mask
else
:
else
:
msa_mask
=
torch
.
cat
(
msa_mask
=
torch
.
cat
(
[
feats
[
"msa_mask"
],
template_embeds
[
"template_mask"
]],
[
feats
[
"msa_mask"
],
template_embeds
[
"template_mask"
]],
dim
=-
2
,
dim
=-
2
,
)
)
del
template_feats
,
template_embeds
,
torsion_angles_mask
del
template_feats
,
template_embeds
# Embed extra MSA features + merge with pairwise embeddings
# Embed extra MSA features + merge with pairwise embeddings
if
self
.
config
.
extra_msa
.
enabled
:
if
self
.
config
.
extra_msa
.
enabled
:
...
@@ -321,7 +325,7 @@ class AlphaFold(nn.Module):
...
@@ -321,7 +325,7 @@ class AlphaFold(nn.Module):
extra_msa_feat
=
self
.
extra_msa_embedder
(
extra_msa_feat
)
extra_msa_feat
=
self
.
extra_msa_embedder
(
extra_msa_feat
)
# [*, N, N, C_z]
# [*, N, N, C_z]
if
not
self
.
globals
.
inplace
or
self
.
globals
.
is_multimer
:
if
not
self
.
globals
.
inplace
:
z
=
self
.
extra_msa_stack
(
z
=
self
.
extra_msa_stack
(
extra_msa_feat
,
extra_msa_feat
,
z
,
z
,
...
@@ -347,7 +351,7 @@ class AlphaFold(nn.Module):
...
@@ -347,7 +351,7 @@ class AlphaFold(nn.Module):
# m: [*, S, N, C_m]
# m: [*, S, N, C_m]
# z: [*, N, N, C_z]
# z: [*, N, N, C_z]
# s: [*, N, C_s]
# s: [*, N, C_s]
if
not
self
.
globals
.
inplace
or
self
.
globals
.
is_multimer
:
if
not
self
.
globals
.
inplace
:
m
,
z
,
s
=
self
.
evoformer
(
m
,
z
,
s
=
self
.
evoformer
(
m
,
m
,
z
,
z
,
...
...
fastfold/model/nn/embedders_multimer.py
View file @
b3af1957
...
@@ -254,18 +254,18 @@ class TemplateSingleEmbedderMultimer(nn.Module):
...
@@ -254,18 +254,18 @@ class TemplateSingleEmbedderMultimer(nn.Module):
template_mask
=
template_chi_mask
[...,
0
]
template_mask
=
template_chi_mask
[...,
0
]
template_
activation
s
=
self
.
template_single_embedder
(
template_
feature
s
=
self
.
template_single_embedder
(
template_features
template_features
)
)
template_
activation
s
=
torch
.
nn
.
functional
.
relu
(
template_
feature
s
=
torch
.
nn
.
functional
.
relu
(
template_
activation
s
template_
feature
s
)
)
template_
activation
s
=
self
.
template_projector
(
template_
feature
s
=
self
.
template_projector
(
template_
activation
s
,
template_
feature
s
,
)
)
out
[
"template_single_embedding"
]
=
(
out
[
"template_single_embedding"
]
=
(
template_
activation
s
template_
feature
s
)
)
out
[
"template_mask"
]
=
template_mask
out
[
"template_mask"
]
=
template_mask
...
@@ -296,6 +296,7 @@ class TemplateEmbedderMultimer(nn.Module):
...
@@ -296,6 +296,7 @@ class TemplateEmbedderMultimer(nn.Module):
templ_dim
,
templ_dim
,
chunk_size
,
chunk_size
,
multichain_mask_2d
,
multichain_mask_2d
,
inplace
):
):
template_embeds
=
[]
template_embeds
=
[]
n_templ
=
batch
[
"template_aatype"
].
shape
[
templ_dim
]
n_templ
=
batch
[
"template_aatype"
].
shape
[
templ_dim
]
...
@@ -307,7 +308,6 @@ class TemplateEmbedderMultimer(nn.Module):
...
@@ -307,7 +308,6 @@ class TemplateEmbedderMultimer(nn.Module):
)
)
single_template_embeds
=
{}
single_template_embeds
=
{}
act
=
0.
template_positions
,
pseudo_beta_mask
=
(
template_positions
,
pseudo_beta_mask
=
(
single_template_feats
[
"template_pseudo_beta"
],
single_template_feats
[
"template_pseudo_beta"
],
...
@@ -361,17 +361,27 @@ class TemplateEmbedderMultimer(nn.Module):
...
@@ -361,17 +361,27 @@ class TemplateEmbedderMultimer(nn.Module):
template_embeds
,
template_embeds
,
)
)
# [*, S_t, N, N, C_z]
if
not
inplace
:
t
=
self
.
template_pair_stack
(
# [*, S_t, N, N, C_z]
template_embeds
[
"template_pair_embedding"
],
template_embeds
[
"template_pair_embedding"
]
=
self
.
template_pair_stack
(
padding_mask_2d
.
unsqueeze
(
-
3
).
to
(
dtype
=
z
.
dtype
),
template_embeds
[
"template_pair_embedding"
],
chunk_size
=
chunk_size
,
padding_mask_2d
.
unsqueeze
(
-
3
).
to
(
dtype
=
z
.
dtype
),
_mask_trans
=
False
,
chunk_size
=
chunk_size
,
)
_mask_trans
=
False
,
)
else
:
template_embeds
[
"template_pair_embedding"
]
=
[
template_embeds
[
"template_pair_embedding"
]]
# [*, S_t, N, N, C_z]
template_embeds
[
"template_pair_embedding"
]
=
self
.
template_pair_stack
.
inplace
(
template_embeds
[
"template_pair_embedding"
],
padding_mask_2d
.
unsqueeze
(
-
3
).
to
(
dtype
=
z
.
dtype
),
chunk_size
=
chunk_size
,
_mask_trans
=
False
,
)[
0
].
to
(
z
.
device
)
# [*, N, N, C_z]
# [*, N, N, C_z]
t
=
torch
.
sum
(
t
,
dim
=-
4
)
/
n_templ
template_embeds
[
"template_pair_embedding"
]
=
torch
.
sum
(
template_embeds
[
"template_pair_embedding"
],
dim
=-
4
)
/
n_templ
t
=
torch
.
nn
.
functional
.
relu
(
t
)
template_embeds
[
"template_pair_embedding"
]
=
torch
.
nn
.
functional
.
relu
(
template_embeds
[
"template_pair_embedding"
])
t
=
self
.
linear_t
(
t
)
template_embeds
[
"template_pair_embedding"
]
=
self
.
linear_t
(
template_embeds
[
"template_pair_embedding"
])
template_embeds
[
"template_pair_embedding"
]
=
t
return
template_embeds
return
template_embeds
fastfold/model/nn/evoformer.py
View file @
b3af1957
...
@@ -229,6 +229,7 @@ class EvoformerBlock(nn.Module):
...
@@ -229,6 +229,7 @@ class EvoformerBlock(nn.Module):
pair_dropout
:
float
,
pair_dropout
:
float
,
inf
:
float
,
inf
:
float
,
eps
:
float
,
eps
:
float
,
is_multimer
:
bool
,
):
):
super
(
EvoformerBlock
,
self
).
__init__
()
super
(
EvoformerBlock
,
self
).
__init__
()
...
@@ -268,6 +269,7 @@ class EvoformerBlock(nn.Module):
...
@@ -268,6 +269,7 @@ class EvoformerBlock(nn.Module):
c_z
,
c_z
,
c_hidden_opm
,
c_hidden_opm
,
)
)
self
.
is_multimer
=
is_multimer
def
forward
(
self
,
def
forward
(
self
,
m
:
torch
.
Tensor
,
m
:
torch
.
Tensor
,
...
@@ -315,6 +317,7 @@ class ExtraMSABlock(nn.Module):
...
@@ -315,6 +317,7 @@ class ExtraMSABlock(nn.Module):
inf
:
float
,
inf
:
float
,
eps
:
float
,
eps
:
float
,
ckpt
:
bool
,
ckpt
:
bool
,
is_multimer
:
bool
,
):
):
super
(
ExtraMSABlock
,
self
).
__init__
()
super
(
ExtraMSABlock
,
self
).
__init__
()
...
@@ -351,6 +354,7 @@ class ExtraMSABlock(nn.Module):
...
@@ -351,6 +354,7 @@ class ExtraMSABlock(nn.Module):
inf
=
inf
,
inf
=
inf
,
eps
=
eps
,
eps
=
eps
,
)
)
self
.
is_multimer
=
is_multimer
def
forward
(
self
,
def
forward
(
self
,
m
:
torch
.
Tensor
,
m
:
torch
.
Tensor
,
...
@@ -415,6 +419,7 @@ class EvoformerStack(nn.Module):
...
@@ -415,6 +419,7 @@ class EvoformerStack(nn.Module):
inf
:
float
,
inf
:
float
,
eps
:
float
,
eps
:
float
,
clear_cache_between_blocks
:
bool
=
False
,
clear_cache_between_blocks
:
bool
=
False
,
is_multimer
:
bool
=
False
,
**
kwargs
,
**
kwargs
,
):
):
"""
"""
...
@@ -474,6 +479,7 @@ class EvoformerStack(nn.Module):
...
@@ -474,6 +479,7 @@ class EvoformerStack(nn.Module):
pair_dropout
=
pair_dropout
,
pair_dropout
=
pair_dropout
,
inf
=
inf
,
inf
=
inf
,
eps
=
eps
,
eps
=
eps
,
is_multimer
=
is_multimer
,
)
)
self
.
blocks
.
append
(
block
)
self
.
blocks
.
append
(
block
)
...
@@ -610,6 +616,7 @@ class ExtraMSAStack(nn.Module):
...
@@ -610,6 +616,7 @@ class ExtraMSAStack(nn.Module):
eps
:
float
,
eps
:
float
,
ckpt
:
bool
,
ckpt
:
bool
,
clear_cache_between_blocks
:
bool
=
False
,
clear_cache_between_blocks
:
bool
=
False
,
is_multimer
:
bool
=
False
,
**
kwargs
,
**
kwargs
,
):
):
super
(
ExtraMSAStack
,
self
).
__init__
()
super
(
ExtraMSAStack
,
self
).
__init__
()
...
@@ -632,6 +639,7 @@ class ExtraMSAStack(nn.Module):
...
@@ -632,6 +639,7 @@ class ExtraMSAStack(nn.Module):
inf
=
inf
,
inf
=
inf
,
eps
=
eps
,
eps
=
eps
,
ckpt
=
ckpt
,
ckpt
=
ckpt
,
is_multimer
=
is_multimer
,
)
)
self
.
blocks
.
append
(
block
)
self
.
blocks
.
append
(
block
)
...
...
fastfold/utils/inject_fastnn.py
View file @
b3af1957
...
@@ -228,10 +228,12 @@ def inject_evoformer(model):
...
@@ -228,10 +228,12 @@ def inject_evoformer(model):
for
block_id
,
ori_block
in
enumerate
(
model
.
evoformer
.
blocks
):
for
block_id
,
ori_block
in
enumerate
(
model
.
evoformer
.
blocks
):
c_m
=
ori_block
.
msa_att_row
.
c_in
c_m
=
ori_block
.
msa_att_row
.
c_in
c_z
=
ori_block
.
msa_att_row
.
c_z
c_z
=
ori_block
.
msa_att_row
.
c_z
is_multimer
=
ori_block
.
is_multimer
fastfold_block
=
EvoformerBlock
(
c_m
=
c_m
,
fastfold_block
=
EvoformerBlock
(
c_m
=
c_m
,
c_z
=
c_z
,
c_z
=
c_z
,
first_block
=
(
block_id
==
0
),
first_block
=
(
block_id
==
0
),
last_block
=
(
block_id
==
len
(
model
.
evoformer
.
blocks
)
-
1
)
last_block
=
(
block_id
==
len
(
model
.
evoformer
.
blocks
)
-
1
),
is_multimer
=
is_multimer
,
)
)
copy_evoformer_para
(
fastfold_block
,
ori_block
)
copy_evoformer_para
(
fastfold_block
,
ori_block
)
...
@@ -249,11 +251,13 @@ def inject_extraMsaBlock(model):
...
@@ -249,11 +251,13 @@ def inject_extraMsaBlock(model):
for
block_id
,
ori_block
in
enumerate
(
model
.
extra_msa_stack
.
blocks
):
for
block_id
,
ori_block
in
enumerate
(
model
.
extra_msa_stack
.
blocks
):
c_m
=
ori_block
.
msa_att_row
.
c_in
c_m
=
ori_block
.
msa_att_row
.
c_in
c_z
=
ori_block
.
msa_att_row
.
c_z
c_z
=
ori_block
.
msa_att_row
.
c_z
is_multimer
=
ori_block
.
is_multimer
new_model_block
=
ExtraMSABlock
(
new_model_block
=
ExtraMSABlock
(
c_m
=
c_m
,
c_m
=
c_m
,
c_z
=
c_z
,
c_z
=
c_z
,
first_block
=
(
block_id
==
0
),
first_block
=
(
block_id
==
0
),
last_block
=
(
block_id
==
len
(
model
.
extra_msa_stack
.
blocks
)
-
1
),
last_block
=
(
block_id
==
len
(
model
.
extra_msa_stack
.
blocks
)
-
1
),
is_multimer
=
is_multimer
)
)
copy_extra_msa_para
(
new_model_block
,
ori_block
)
copy_extra_msa_para
(
new_model_block
,
ori_block
)
...
...
inference.py
View file @
b3af1957
...
@@ -115,6 +115,7 @@ def inference_model(rank, world_size, result_q, batch, args):
...
@@ -115,6 +115,7 @@ def inference_model(rank, world_size, result_q, batch, args):
if
args
.
chunk_size
:
if
args
.
chunk_size
:
config
.
globals
.
chunk_size
=
args
.
chunk_size
config
.
globals
.
chunk_size
=
args
.
chunk_size
config
.
globals
.
inplace
=
args
.
inplace
config
.
globals
.
inplace
=
args
.
inplace
config
.
globals
.
is_multimer
=
args
.
model_preset
==
'multimer'
model
=
AlphaFold
(
config
)
model
=
AlphaFold
(
config
)
import_jax_weights_
(
model
,
args
.
param_path
,
version
=
args
.
model_name
)
import_jax_weights_
(
model
,
args
.
param_path
,
version
=
args
.
model_name
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment