OpenDAS / FastFold · Commits
"tests/vscode:/vscode.git/clone" did not exist on "bb1b76d3bf9ef78a827086d1b9449975237ecbac"
Commit 6d8b97ec (Unverified)
support multimer (#54)
Authored Aug 31, 2022 by Fazzie-Maqianli; committed by GitHub, Aug 31, 2022
Parent: 1efccb6c
Showing 10 changed files with 519 additions and 83 deletions (+519 −83)
README.md                                              +1    −1
fastfold/config.py                                     +93   −1
fastfold/data/tools/hmmsearch.py                       +148  −0
fastfold/distributed/comm.py                           +1    −1
fastfold/model/fastnn/blocks.py                        +15   −8
fastfold/model/hub/alphafold.py                        +9    −3
fastfold/model/nn/embedders.py                         +141  −1
fastfold/model/nn/template.py                          +56   −30
fastfold/workflow/template/fastfold_data_workflow.py   +2    −0
inference.py                                           +53   −38
README.md  (+1, −1)

@@ -59,7 +59,7 @@ Run the following command to build a docker image from Dockerfile provided.
 ```shell
 cd ColossalAI
-docker build -t fastfold ./docker
+docker build -t Fastfold ./docker
 ```
 Run the following command to start the docker container in interactive mode.
fastfold/config.py  (+93, −1)

@@ -74,6 +74,24 @@ def model_config(name, train=False, low_prec=False):
         c.model.template.enabled = False
         c.model.heads.tm.enabled = True
         c.loss.tm.weight = 0.1
+    elif name == "relax":
+        pass
+    elif "multimer" in name:
+        c.globals.is_multimer = True
+        c.data.predict.max_msa_clusters = 252  # 128 for monomer
+        c.model.structure_module.trans_scale_factor = 20  # 10 for monomer
+        for k, v in multimer_model_config_update.items():
+            c.model[k] = v
+        c.data.common.unsupervised_features.extend(
+            [
+                "msa_mask",
+                "seq_mask",
+                "asym_id",
+                "entity_id",
+                "sym_id",
+            ]
+        )
     else:
         raise ValueError("Invalid model name")

@@ -275,6 +293,7 @@ config = mlc.ConfigDict(
             "c_e": c_e,
             "c_s": c_s,
             "eps": eps,
+            "is_multimer": False,
         },
         "model": {
             "_mask_trans": False,

@@ -495,3 +514,76 @@ config = mlc.ConfigDict(
         "ema": {"decay": 0.999},
     }
 )
+
+multimer_model_config_update = {
+    "input_embedder": {
+        "tf_dim": 21,
+        "msa_dim": 49,
+        "c_z": c_z,
+        "c_m": c_m,
+        "relpos_k": 32,
+        "max_relative_chain": 2,
+        "max_relative_idx": 32,
+        "use_chain_relative": True,
+    },
+    "template": {
+        "distogram": {
+            "min_bin": 3.25,
+            "max_bin": 50.75,
+            "no_bins": 39,
+        },
+        "template_pair_embedder": {
+            "c_z": c_z,
+            "c_out": 64,
+            "c_dgram": 39,
+            "c_aatype": 22,
+        },
+        "template_single_embedder": {
+            "c_in": 34,
+            "c_m": c_m,
+        },
+        "template_pair_stack": {
+            "c_t": c_t,
+            # DISCREPANCY: c_hidden_tri_att here is given in the supplement
+            # as 64. In the code, it's 16.
+            "c_hidden_tri_att": 16,
+            "c_hidden_tri_mul": 64,
+            "no_blocks": 2,
+            "no_heads": 4,
+            "pair_transition_n": 2,
+            "dropout_rate": 0.25,
+            "blocks_per_ckpt": blocks_per_ckpt,
+            "inf": 1e9,
+        },
+        "c_t": c_t,
+        "c_z": c_z,
+        "inf": 1e5,  # 1e9,
+        "eps": eps,  # 1e-6,
+        "enabled": templates_enabled,
+        "embed_angles": embed_template_torsion_angles,
+    },
+    "heads": {
+        "lddt": {
+            "no_bins": 50,
+            "c_in": c_s,
+            "c_hidden": 128,
+        },
+        "distogram": {
+            "c_z": c_z,
+            "no_bins": aux_distogram_bins,
+        },
+        "tm": {
+            "c_z": c_z,
+            "no_bins": aux_distogram_bins,
+            "enabled": tm_enabled,
+        },
+        "masked_msa": {
+            "c_m": c_m,
+            "c_out": 22,
+        },
+        "experimentally_resolved": {
+            "c_s": c_s,
+            "c_out": 37,
+        },
+    },
+}
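
For orientation, a minimal usage sketch of the new branch. The preset name "model_1_multimer" is an assumption; per the diff, any name containing "multimer" takes this path.

```python
# Hypothetical usage sketch; "model_1_multimer" is an assumed preset name --
# the branch above matches any name containing "multimer".
from fastfold.config import model_config

c = model_config("model_1_multimer")
assert c.globals.is_multimer
assert c.data.predict.max_msa_clusters == 252              # 128 for monomer
assert c.model.structure_module.trans_scale_factor == 20   # 10 for monomer
```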
fastfold/data/tools/hmmsearch.py  (new file mode 100644, +148, −0)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for hmmsearch - search profile against a sequence db."""

import os
import subprocess
import logging
from typing import Optional, Sequence

from fastfold.data import parsers
from fastfold.data.tools import hmmbuild
from fastfold.utils import general_utils as utils


class Hmmsearch(object):
    """Python wrapper of the hmmsearch binary."""

    def __init__(
        self,
        *,
        binary_path: str,
        hmmbuild_binary_path: str,
        database_path: str,
        flags: Optional[Sequence[str]] = None,
    ):
        """Initializes the Python hmmsearch wrapper.

        Args:
            binary_path: The path to the hmmsearch executable.
            hmmbuild_binary_path: The path to the hmmbuild executable. Used to
                build an hmm from an input a3m.
            database_path: The path to the hmmsearch database (FASTA format).
            flags: List of flags to be used by hmmsearch.

        Raises:
            RuntimeError: If hmmsearch binary not found within the path.
        """
        self.binary_path = binary_path
        self.hmmbuild_runner = hmmbuild.Hmmbuild(binary_path=hmmbuild_binary_path)
        self.database_path = database_path
        if flags is None:
            # Default hmmsearch run settings.
            flags = [
                "--F1", "0.1",
                "--F2", "0.1",
                "--F3", "0.1",
                "--incE", "100",
                "-E", "100",
                "--domE", "100",
                "--incdomE", "100",
            ]
        self.flags = flags

        if not os.path.exists(self.database_path):
            logging.error("Could not find hmmsearch database %s", database_path)
            raise ValueError(f"Could not find hmmsearch database {database_path}")

    @property
    def output_format(self) -> str:
        return "sto"

    @property
    def input_format(self) -> str:
        return "sto"

    def query(self, msa_sto: str, output_dir: Optional[str] = None) -> str:
        """Queries the database using hmmsearch using a given stockholm msa."""
        hmm = self.hmmbuild_runner.build_profile_from_sto(
            msa_sto, model_construction="hand"
        )
        return self.query_with_hmm(hmm, output_dir)

    def query_with_hmm(self, hmm: str, output_dir: Optional[str] = None) -> str:
        """Queries the database using hmmsearch using a given hmm."""
        with utils.tmpdir_manager() as query_tmp_dir:
            hmm_input_path = os.path.join(query_tmp_dir, "query.hmm")
            output_dir = query_tmp_dir if output_dir is None else output_dir
            out_path = os.path.join(output_dir, "hmm_output.sto")
            with open(hmm_input_path, "w") as f:
                f.write(hmm)

            cmd = [
                self.binary_path,
                "--noali",  # Don't include the alignment in stdout.
                "--cpu", "8",
            ]
            # If adding flags, we have to do so before the output and input:
            if self.flags:
                cmd.extend(self.flags)
            cmd.extend([
                "-A", out_path,
                hmm_input_path,
                self.database_path,
            ])

            logging.info("Launching sub-process %s", cmd)
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            with utils.timing(
                f"hmmsearch ({os.path.basename(self.database_path)}) query"
            ):
                stdout, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                raise RuntimeError(
                    "hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n"
                    % (stdout.decode("utf-8"), stderr.decode("utf-8"))
                )

            with open(out_path) as f:
                out_msa = f.read()

        return out_msa

    @staticmethod
    def get_template_hits(
        output_string: str, input_sequence: str
    ) -> Sequence[parsers.TemplateHit]:
        """Gets parsed template hits from the raw string output by the tool."""
        template_hits = parsers.parse_hmmsearch_sto(
            output_string,
            input_sequence,
        )
        return template_hits
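
A hedged usage sketch of the new wrapper; the binary paths, database path, input Stockholm file, and query sequence below are placeholders, not values from the commit.

```python
# Hypothetical usage sketch; all paths and the input sequence are placeholders.
from fastfold.data.tools.hmmsearch import Hmmsearch

searcher = Hmmsearch(
    binary_path="/usr/bin/hmmsearch",          # assumed HMMER install location
    hmmbuild_binary_path="/usr/bin/hmmbuild",  # assumed HMMER install location
    database_path="/data/pdb_seqres.txt",      # assumed FASTA template database
)

with open("query_msa.sto") as f:               # assumed Stockholm-format MSA
    msa_sto = f.read()

sto_out = searcher.query(msa_sto, output_dir="./hmm_out")
hits = Hmmsearch.get_template_hits(sto_out, input_sequence="MKTAYIAKQR")
```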
fastfold/distributed/comm.py  (+1, −1)

@@ -43,7 +43,7 @@ def _gather(tensor: Tensor, dim: int = -1) -> Tensor:
     if gpc.get_world_size(ParallelMode.TENSOR) == 1:
         return tensor

-    if dim == 1:
+    if dim == 1 and list(tensor.shape)[0] == 1:
         output_shape = list(tensor.shape)
         output_shape[1] *= gpc.get_world_size(ParallelMode.TENSOR)
         output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
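
A toy illustration of the shape bookkeeping behind the tightened guard (plain PyTorch, not FastFold's distributed code): the dim == 1 fast path pre-allocates an output whose second dimension is scaled by the tensor-parallel world size, which is what the new `list(tensor.shape)[0] == 1` check protects. The world size of 4 here is an assumption.

```python
import torch

world_size = 4             # stand-in for gpc.get_world_size(ParallelMode.TENSOR)
t = torch.zeros(1, 8, 16)  # leading dim 1: the fast path applies

output_shape = list(t.shape)
output_shape[1] *= world_size  # gather along dim 1 -> [1, 32, 16]
output = torch.empty(output_shape, dtype=t.dtype, device=t.device)
print(output.shape)            # torch.Size([1, 32, 16])
```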
fastfold/model/fastnn/blocks.py  (+15, −8)

@@ -32,7 +32,7 @@ from fastfold.distributed.comm_async import All_to_All_Async, All_to_All_Async_O
 class EvoformerBlock(nn.Module):

-    def __init__(self, c_m: int, c_z: int, first_block: bool, last_block: bool):
+    def __init__(self, c_m: int, c_z: int, first_block: bool, last_block: bool, is_multimer: bool = False):
         super(EvoformerBlock, self).__init__()
         self.first_block = first_block

@@ -41,6 +41,7 @@ class EvoformerBlock(nn.Module):
         self.msa_stack = MSAStack(c_m, c_z, p_drop=0.15)
         self.communication = OutProductMean(n_feat=c_m, n_feat_out=c_z, n_feat_proj=32)
         self.pair_stack = PairStack(d_pair=c_z)
+        self.is_multimer = is_multimer

     def forward(
         self,

@@ -73,12 +74,19 @@ class EvoformerBlock(nn.Module):
             msa_mask = torch.nn.functional.pad(msa_mask, (0, padding_size))
             pair_mask = torch.nn.functional.pad(pair_mask, (0, padding_size, 0, padding_size))

-        m = self.msa_stack(m, z, msa_mask)
-        z = z + self.communication(m, msa_mask)
-        m, work = All_to_All_Async.apply(m, 1, 2)
-        z = self.pair_stack(z, pair_mask)
-        m = All_to_All_Async_Opp.apply(m, work, 1, 2)
+        if not self.is_multimer:
+            m = self.msa_stack(m, z, msa_mask)
+            z = z + self.communication(m, msa_mask)
+            m, work = All_to_All_Async.apply(m, 1, 2)
+            z = self.pair_stack(z, pair_mask)
+            m = All_to_All_Async_Opp.apply(m, work, 1, 2)
+        else:
+            z = z + self.communication(m, msa_mask)
+            z_ori = z
+            m, work = All_to_All_Async.apply(m, 1, 2)
+            z = self.pair_stack(z, pair_mask)
+            m = All_to_All_Async_Opp.apply(m, work, 1, 2)
+            m = self.msa_stack(m, z_ori, msa_mask)

         if self.last_block:
             m = m.squeeze(0)

@@ -260,7 +268,6 @@ class TemplatePairStackBlock(nn.Module):
             single_templates[i] = single

         z = torch.cat(single_templates, dim=-4)
-
         if self.last_block:
             z = gather(z, dim=1)
             z = z[:, :-padding_size, :-padding_size, :]
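
A schematic sketch of the two orderings in EvoformerBlock.forward; msa_stack, communication, and pair_stack stand in for MSAStack, OutProductMean, and PairStack, and the async all-to-all exchange is omitted.

```python
# Schematic only; the callables are stand-ins for FastFold's modules.
def monomer_order(m, z, msa_stack, communication, pair_stack):
    m = msa_stack(m, z)       # MSA update sees the incoming pair rep
    z = z + communication(m)  # outer-product mean of the *updated* MSA
    z = pair_stack(z)
    return m, z

def multimer_order(m, z, msa_stack, communication, pair_stack):
    z = z + communication(m)  # outer-product mean of the *incoming* MSA
    z_ori = z                 # pair rep before the pair stack
    z = pair_stack(z)
    m = msa_stack(m, z_ori)   # MSA update sees the pre-pair-stack z
    return m, z
```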
fastfold/model/hub/alphafold.py  (+9, −3)

@@ -26,6 +26,7 @@ from fastfold.utils.feats import (
 )
 from fastfold.model.nn.embedders import (
     InputEmbedder,
+    InputEmbedderMultimer,
     RecyclingEmbedder,
     TemplateAngleEmbedder,
     TemplatePairEmbedder,

@@ -69,6 +70,11 @@ class AlphaFold(nn.Module):
         extra_msa_config = config.extra_msa

         # Main trunk + structure module
-        self.input_embedder = InputEmbedder(
-            **config["input_embedder"],
-        )
+        if self.globals.is_multimer:
+            self.input_embedder = InputEmbedderMultimer(
+                **config["input_embedder"],
+            )
+        else:
+            self.input_embedder = InputEmbedder(
+                **config["input_embedder"],
+            )
fastfold/model/nn/embedders.py  (+141, −1)

@@ -15,7 +15,7 @@
 import torch
 import torch.nn as nn
-from typing import Tuple
+from typing import Tuple, Dict

 from fastfold.model.nn.primitives import Linear, LayerNorm
 from fastfold.utils.tensor_utils import one_hot

@@ -125,6 +125,146 @@ class InputEmbedder(nn.Module):
         return msa_emb, pair_emb


+class InputEmbedderMultimer(nn.Module):
+    """
+    Embeds a subset of the input features.
+
+    Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
+    """
+
+    def __init__(
+        self,
+        tf_dim: int,
+        msa_dim: int,
+        c_z: int,
+        c_m: int,
+        max_relative_idx: int,
+        use_chain_relative: bool,
+        max_relative_chain: int,
+        **kwargs,
+    ):
+        """
+        Args:
+            tf_dim:
+                Final dimension of the target features
+            msa_dim:
+                Final dimension of the MSA features
+            c_z:
+                Pair embedding dimension
+            c_m:
+                MSA embedding dimension
+            relpos_k:
+                Window size used in relative positional encoding
+        """
+        super(InputEmbedderMultimer, self).__init__()
+
+        self.tf_dim = tf_dim
+        self.msa_dim = msa_dim
+
+        self.c_z = c_z
+        self.c_m = c_m
+
+        self.linear_tf_z_i = Linear(tf_dim, c_z)
+        self.linear_tf_z_j = Linear(tf_dim, c_z)
+        self.linear_tf_m = Linear(tf_dim, c_m)
+        self.linear_msa_m = Linear(msa_dim, c_m)
+
+        # RPE stuff
+        self.max_relative_idx = max_relative_idx
+        self.use_chain_relative = use_chain_relative
+        self.max_relative_chain = max_relative_chain
+        if self.use_chain_relative:
+            self.no_bins = (
+                2 * max_relative_idx + 2 + 1 + 2 * max_relative_chain + 2
+            )
+        else:
+            self.no_bins = 2 * max_relative_idx + 1
+        self.linear_relpos = Linear(self.no_bins, c_z)
+
+    def relpos(self, batch: Dict[str, torch.Tensor]):
+        pos = batch["residue_index"]
+        asym_id = batch["asym_id"]
+        asym_id_same = asym_id[..., None] == asym_id[..., None, :]
+        offset = pos[..., None] - pos[..., None, :]
+
+        clipped_offset = torch.clamp(
+            offset + self.max_relative_idx, 0, 2 * self.max_relative_idx
+        )
+
+        rel_feats = []
+        if self.use_chain_relative:
+            final_offset = torch.where(
+                asym_id_same,
+                clipped_offset,
+                (2 * self.max_relative_idx + 1) * torch.ones_like(clipped_offset),
+            )
+
+            rel_pos = torch.nn.functional.one_hot(
+                final_offset,
+                2 * self.max_relative_idx + 2,
+            )
+            rel_feats.append(rel_pos)
+
+            entity_id = batch["entity_id"]
+            entity_id_same = entity_id[..., None] == entity_id[..., None, :]
+            rel_feats.append(entity_id_same[..., None])
+
+            sym_id = batch["sym_id"]
+            rel_sym_id = sym_id[..., None] - sym_id[..., None, :]
+
+            max_rel_chain = self.max_relative_chain
+
+            clipped_rel_chain = torch.clamp(
+                rel_sym_id + max_rel_chain,
+                0,
+                2 * max_rel_chain,
+            )
+
+            final_rel_chain = torch.where(
+                entity_id_same,
+                clipped_rel_chain,
+                (2 * max_rel_chain + 1) * torch.ones_like(clipped_rel_chain),
+            )
+
+            rel_chain = torch.nn.functional.one_hot(
+                final_rel_chain.long(),
+                2 * max_rel_chain + 2,
+            )
+
+            rel_feats.append(rel_chain)
+        else:
+            rel_pos = torch.nn.functional.one_hot(
+                clipped_offset,
+                2 * self.max_relative_idx + 1,
+            )
+            rel_feats.append(rel_pos)
+
+        rel_feat = torch.cat(rel_feats, dim=-1).to(self.linear_relpos.weight.dtype)
+
+        return self.linear_relpos(rel_feat)
+
+    def forward(
+        self, batch: Dict[str, torch.Tensor]
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        tf = batch["target_feat"]
+        msa = batch["msa_feat"]
+
+        # [*, N_res, c_z]
+        tf_emb_i = self.linear_tf_z_i(tf)
+        tf_emb_j = self.linear_tf_z_j(tf)
+
+        # [*, N_res, N_res, c_z]
+        pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
+        pair_emb = pair_emb + self.relpos(batch)
+
+        # [*, N_clust, N_res, c_m]
+        n_clust = msa.shape[-3]
+        tf_m = (
+            self.linear_tf_m(tf)
+            .unsqueeze(-3)
+            .expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
+        )
+        msa_emb = self.linear_msa_m(msa) + tf_m
+
+        return msa_emb, pair_emb
+
+
 class RecyclingEmbedder(nn.Module):
     """
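
As a worked check of the no_bins arithmetic with the multimer defaults this commit puts in multimer_model_config_update (max_relative_idx=32, max_relative_chain=2, use_chain_relative=True):

```python
max_relative_idx = 32
max_relative_chain = 2

# 2*32 + 2 = 66 one-hot relative-offset bins (including the "different
# chain" overflow bin), + 1 same-entity flag, + 2*2 + 2 = 6 relative-chain
# bins (including the "different entity" overflow bin).
no_bins = 2 * max_relative_idx + 2 + 1 + 2 * max_relative_chain + 2
print(no_bins)  # 73 -> input width of linear_relpos
```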
fastfold/model/nn/template.py  (+56, −30)

@@ -142,6 +142,7 @@ class TemplatePairStackBlock(nn.Module):
         pair_transition_n: int,
         dropout_rate: float,
         inf: float,
+        is_multimer: bool = False,
         **kwargs,
     ):
         super(TemplatePairStackBlock, self).__init__()

@@ -153,6 +154,7 @@ class TemplatePairStackBlock(nn.Module):
         self.pair_transition_n = pair_transition_n
         self.dropout_rate = dropout_rate
         self.inf = inf
+        self.is_multimer = is_multimer

         self.dropout_row = DropoutRowwise(self.dropout_rate)
         self.dropout_col = DropoutColumnwise(self.dropout_rate)

@@ -196,6 +198,7 @@ class TemplatePairStackBlock(nn.Module):
         single_templates_masks = [
             m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)
         ]
-        for i in range(len(single_templates)):
-            single = single_templates[i]
-            single_mask = single_templates_masks[i]
+        if not self.is_multimer:
+            for i in range(len(single_templates)):
+                single = single_templates[i]
+                single_mask = single_templates_masks[i]

@@ -233,6 +236,29 @@ class TemplatePairStackBlock(nn.Module):
                 )
                 single_templates[i] = single
+        else:
+            for i in range(len(single_templates)):
+                single = single_templates[i]
+                single_mask = single_templates_masks[i]
+                single = single + self.dropout_row(
+                    self.tri_att_start(single, chunk_size=chunk_size, mask=single_mask)
+                )
+                single = single + self.dropout_col(
+                    self.tri_att_end(single, chunk_size=chunk_size, mask=single_mask)
+                )
+                single = single + self.dropout_row(
+                    self.tri_mul_out(single, mask=single_mask)
+                )
+                single = single + self.dropout_row(
+                    self.tri_mul_in(single, mask=single_mask)
+                )
+                single = single + self.pair_transition(
+                    single,
+                    mask=single_mask if _mask_trans else None,
+                    chunk_size=chunk_size,
+                )
+                single_templates[i] = single

         z = torch.cat(single_templates, dim=-4)
fastfold/workflow/template/fastfold_data_workflow.py  (+2, −0)

 import os
 import time
 from multiprocessing import cpu_count

+import ray
 from ray import workflow

 from fastfold.workflow.factory import JackHmmerFactory, HHSearchFactory, HHBlitsFactory
 from fastfold.workflow import batch_run

@@ -80,6 +81,7 @@ class FastFoldDataWorkFlow:
             print("Workflow not found. Clean. Skipping")
             pass
 ...
         # prepare alignment directory for alignment outputs
         if alignment_dir is None:
             alignment_dir = os.path.join(output_dir, "alignment")
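
The new top-level import ray implies the class now calls the Ray API directly; the single added line in the second hunk is not captured on this page. For context only, a generic Ray initialization sketch (an assumption, not the commit's actual line):

```python
import ray

# Start (or attach to) a local Ray runtime before submitting workflow steps.
if not ray.is_initialized():
    ray.init()
```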
inference.py  (+53, −38)

@@ -23,6 +23,7 @@ from datetime import date
 import numpy as np
 import torch
 import torch.multiprocessing as mp
+import pickle

 from fastfold.model.hub import AlphaFold
 import fastfold

@@ -111,6 +112,7 @@ def inference_model(rank, world_size, result_q, batch, args):
 def main(args):
     config = model_config(args.model_name)
+    global_is_multimer = True if args.model_preset == "multimer" else False
     template_featurizer = templates.TemplateHitFeaturizer(
         mmcif_dir=args.template_mmcif_dir,

@@ -147,6 +149,7 @@ def main(args):
     seqs, tags = parse_fasta(fasta)
     for tag, seq in zip(tags, seqs):
+        print(f"tag: {tag} seq: {seq}")
         batch = [None]
         fasta_path = os.path.join(args.output_dir, "tmp.fasta")

@@ -155,44 +158,48 @@ def main(args):
         print("Generating features...")
         local_alignment_dir = os.path.join(alignment_dir, tag)
-        if (args.use_precomputed_alignments is None):
-            if not os.path.exists(local_alignment_dir):
-                os.makedirs(local_alignment_dir)
-            if args.enable_workflow:
-                print("Running alignment with ray workflow...")
-                alignment_data_workflow_runner = FastFoldDataWorkFlow(
-                    jackhmmer_binary_path=args.jackhmmer_binary_path,
-                    hhblits_binary_path=args.hhblits_binary_path,
-                    hhsearch_binary_path=args.hhsearch_binary_path,
-                    uniref90_database_path=args.uniref90_database_path,
-                    mgnify_database_path=args.mgnify_database_path,
-                    bfd_database_path=args.bfd_database_path,
-                    uniclust30_database_path=args.uniclust30_database_path,
-                    pdb70_database_path=args.pdb70_database_path,
-                    use_small_bfd=use_small_bfd,
-                    no_cpus=args.cpus,
-                )
-                t = time.perf_counter()
-                alignment_data_workflow_runner.run(fasta_path, output_dir=output_dir_base, alignment_dir=local_alignment_dir)
-                print(f"Alignment data workflow time: {time.perf_counter() - t}")
-            else:
-                alignment_runner = data_pipeline.AlignmentRunner(
-                    jackhmmer_binary_path=args.jackhmmer_binary_path,
-                    hhblits_binary_path=args.hhblits_binary_path,
-                    hhsearch_binary_path=args.hhsearch_binary_path,
-                    uniref90_database_path=args.uniref90_database_path,
-                    mgnify_database_path=args.mgnify_database_path,
-                    bfd_database_path=args.bfd_database_path,
-                    uniclust30_database_path=args.uniclust30_database_path,
-                    pdb70_database_path=args.pdb70_database_path,
-                    use_small_bfd=use_small_bfd,
-                    no_cpus=args.cpus,
-                )
-                alignment_runner.run(fasta_path, local_alignment_dir)
-
-        feature_dict = data_processor.process_fasta(fasta_path=fasta_path,
-                                                    alignment_dir=local_alignment_dir)
+        # if global_is_multimer:
+        #     print("Multimer")
+        # else:
+        # if (args.use_precomputed_alignments is None):
+        #     if not os.path.exists(local_alignment_dir):
+        #         os.makedirs(local_alignment_dir)
+        #     if args.enable_workflow:
+        #         print("Running alignment with ray workflow...")
+        #         alignment_data_workflow_runner = FastFoldDataWorkFlow(
+        #             jackhmmer_binary_path=args.jackhmmer_binary_path,
+        #             hhblits_binary_path=args.hhblits_binary_path,
+        #             hhsearch_binary_path=args.hhsearch_binary_path,
+        #             uniref90_database_path=args.uniref90_database_path,
+        #             mgnify_database_path=args.mgnify_database_path,
+        #             bfd_database_path=args.bfd_database_path,
+        #             uniclust30_database_path=args.uniclust30_database_path,
+        #             pdb70_database_path=args.pdb70_database_path,
+        #             use_small_bfd=use_small_bfd,
+        #             no_cpus=args.cpus,
+        #         )
+        #         t = time.perf_counter()
+        #         alignment_data_workflow_runner.run(fasta_path, output_dir=output_dir_base, alignment_dir=local_alignment_dir)
+        #         print(f"Alignment data workflow time: {time.perf_counter() - t}")
+        #     else:
+        #         alignment_runner = data_pipeline.AlignmentRunner(
+        #             jackhmmer_binary_path=args.jackhmmer_binary_path,
+        #             hhblits_binary_path=args.hhblits_binary_path,
+        #             hhsearch_binary_path=args.hhsearch_binary_path,
+        #             uniref90_database_path=args.uniref90_database_path,
+        #             mgnify_database_path=args.mgnify_database_path,
+        #             bfd_database_path=args.bfd_database_path,
+        #             uniclust30_database_path=args.uniclust30_database_path,
+        #             pdb70_database_path=args.pdb70_database_path,
+        #             use_small_bfd=use_small_bfd,
+        #             no_cpus=args.cpus,
+        #         )
+        #         alignment_runner.run(fasta_path, local_alignment_dir)
+
+        # feature_dict = data_processor.process_fasta(fasta_path=fasta_path,
+        #                                             alignment_dir=local_alignment_dir)
+        feature_dict = pickle.load(open("/home/lcmql/data/features_pdb1o5d.pkl", "rb"))

         # Remove temporary FASTA file
         os.remove(fasta_path)

@@ -289,6 +296,14 @@ if __name__ == "__main__":
         default='full_dbs',
         choices=('reduced_dbs', 'full_dbs'))
     parser.add_argument('--data_random_seed', type=str, default=None)
+    parser.add_argument(
+        "--model_preset",
+        type=str,
+        default="monomer",
+        choices=["monomer", "multimer"],
+        help="Choose preset model configuration - the monomer model, the monomer model with "
+        "extra ensembling, monomer model with pTM head, or multimer model",
+    )
     add_data_args(parser)
     args = parser.parse_args()
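
A self-contained sketch of the new --model_preset flag and the boolean it feeds in main(). Note that the hunk above also replaces the computed feature_dict with a hard-coded local pickle (/home/lcmql/data/features_pdb1o5d.pkl) and comments out the alignment pipeline, so inference at this commit only runs as-is on a machine where that file exists.

```python
# Self-contained sketch of the new --model_preset flag (argparse only).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--model_preset",
    type=str,
    default="monomer",
    choices=["monomer", "multimer"],
)
args = parser.parse_args(["--model_preset", "multimer"])
global_is_multimer = args.model_preset == "multimer"
print(global_is_multimer)  # True
```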