OpenDAS / FastFold / Commits / 6d8b97ec
"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "6df2505bf0352a7580b33f17ce6844afe04fb7be"
Unverified commit 6d8b97ec, authored Aug 31, 2022 by Fazzie-Maqianli, committed by GitHub on Aug 31, 2022

support multimer (#54)

parent 1efccb6c
Showing 10 changed files with 519 additions and 83 deletions.
- README.md (+1 −1)
- fastfold/config.py (+93 −1)
- fastfold/data/tools/hmmsearch.py (+148 −0, new file)
- fastfold/distributed/comm.py (+1 −1)
- fastfold/model/fastnn/blocks.py (+15 −8)
- fastfold/model/hub/alphafold.py (+9 −3)
- fastfold/model/nn/embedders.py (+141 −1)
- fastfold/model/nn/template.py (+56 −30)
- fastfold/workflow/template/fastfold_data_workflow.py (+2 −0)
- inference.py (+53 −38)
README.md (view file @ 6d8b97ec)
...
````diff
@@ -59,7 +59,7 @@ Run the following command to build a docker image from Dockerfile provided.
 ```shell
 cd ColossalAI
-docker build -t fastfold ./docker
+docker build -t Fastfold ./docker
 ```
 Run the following command to start the docker container in interactive mode.
````
...
fastfold/config.py (view file @ 6d8b97ec)
...
```diff
@@ -74,6 +74,24 @@ def model_config(name, train=False, low_prec=False):
         c.model.template.enabled = False
         c.model.heads.tm.enabled = True
         c.loss.tm.weight = 0.1
     elif name == "relax":
         pass
+    elif "multimer" in name:
+        c.globals.is_multimer = True
+        c.data.predict.max_msa_clusters = 252  # 128 for monomer
+        c.model.structure_module.trans_scale_factor = 20  # 10 for monomer
+        for k, v in multimer_model_config_update.items():
+            c.model[k] = v
+        c.data.common.unsupervised_features.extend(
+            [
+                "msa_mask",
+                "seq_mask",
+                "asym_id",
+                "entity_id",
+                "sym_id",
+            ]
+        )
     else:
         raise ValueError("Invalid model name")
```
...
```diff
@@ -275,6 +293,7 @@ config = mlc.ConfigDict(
             "c_e": c_e,
             "c_s": c_s,
             "eps": eps,
+            "is_multimer": False,
         },
         "model": {
             "_mask_trans": False,
```
...
```diff
@@ -494,4 +513,77 @@ config = mlc.ConfigDict(
         },
         "ema": {"decay": 0.999},
     }
-)
\ No newline at end of file
+)
+
+multimer_model_config_update = {
+    "input_embedder": {
+        "tf_dim": 21,
+        "msa_dim": 49,
+        "c_z": c_z,
+        "c_m": c_m,
+        "relpos_k": 32,
+        "max_relative_chain": 2,
+        "max_relative_idx": 32,
+        "use_chain_relative": True,
+    },
+    "template": {
+        "distogram": {
+            "min_bin": 3.25,
+            "max_bin": 50.75,
+            "no_bins": 39,
+        },
+        "template_pair_embedder": {
+            "c_z": c_z,
+            "c_out": 64,
+            "c_dgram": 39,
+            "c_aatype": 22,
+        },
+        "template_single_embedder": {
+            "c_in": 34,
+            "c_m": c_m,
+        },
+        "template_pair_stack": {
+            "c_t": c_t,
+            # DISCREPANCY: c_hidden_tri_att here is given in the supplement
+            # as 64. In the code, it's 16.
+            "c_hidden_tri_att": 16,
+            "c_hidden_tri_mul": 64,
+            "no_blocks": 2,
+            "no_heads": 4,
+            "pair_transition_n": 2,
+            "dropout_rate": 0.25,
+            "blocks_per_ckpt": blocks_per_ckpt,
+            "inf": 1e9,
+        },
+        "c_t": c_t,
+        "c_z": c_z,
+        "inf": 1e5,  # 1e9,
+        "eps": eps,  # 1e-6,
+        "enabled": templates_enabled,
+        "embed_angles": embed_template_torsion_angles,
+    },
+    "heads": {
+        "lddt": {
+            "no_bins": 50,
+            "c_in": c_s,
+            "c_hidden": 128,
+        },
+        "distogram": {
+            "c_z": c_z,
+            "no_bins": aux_distogram_bins,
+        },
+        "tm": {
+            "c_z": c_z,
+            "no_bins": aux_distogram_bins,
+            "enabled": tm_enabled,
+        },
+        "masked_msa": {
+            "c_m": c_m,
+            "c_out": 22,
+        },
+        "experimentally_resolved": {
+            "c_s": c_s,
+            "c_out": 37,
+        },
+    },
+}
```
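For context, here is a minimal sketch of what the new `"multimer"` branch of `model_config` does to an existing config. The `ml_collections.ConfigDict` usage mirrors the file above; the model name `model_1_multimer` is hypothetical and only needs to contain the substring `multimer`.

```python
import ml_collections as mlc

# A stripped-down stand-in for FastFold's full config dict.
c = mlc.ConfigDict({
    "globals": {"is_multimer": False},
    "data": {"predict": {"max_msa_clusters": 128}},
    "model": {"structure_module": {"trans_scale_factor": 10}},
})

name = "model_1_multimer"  # hypothetical; any name containing "multimer"
if "multimer" in name:
    c.globals.is_multimer = True
    c.data.predict.max_msa_clusters = 252               # 128 for monomer
    c.model.structure_module.trans_scale_factor = 20    # 10 for monomer

print(c.globals.is_multimer, c.data.predict.max_msa_clusters)  # True 252
```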
fastfold/data/tools/hmmsearch.py (new file, 0 → 100644; view file @ 6d8b97ec)
```python
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for hmmsearch - search profile against a sequence db."""
import os
import subprocess
import logging
from typing import Optional, Sequence

from fastfold.data import parsers
from fastfold.data.tools import hmmbuild
from fastfold.utils import general_utils as utils


class Hmmsearch(object):
    """Python wrapper of the hmmsearch binary."""

    def __init__(
        self,
        *,
        binary_path: str,
        hmmbuild_binary_path: str,
        database_path: str,
        flags: Optional[Sequence[str]] = None,
    ):
        """Initializes the Python hmmsearch wrapper.

        Args:
            binary_path: The path to the hmmsearch executable.
            hmmbuild_binary_path: The path to the hmmbuild executable. Used to
                build an hmm from an input a3m.
            database_path: The path to the hmmsearch database (FASTA format).
            flags: List of flags to be used by hmmsearch.

        Raises:
            RuntimeError: If hmmsearch binary not found within the path.
        """
        self.binary_path = binary_path
        self.hmmbuild_runner = hmmbuild.Hmmbuild(binary_path=hmmbuild_binary_path)
        self.database_path = database_path

        if flags is None:
            # Default hmmsearch run settings.
            flags = [
                "--F1", "0.1",
                "--F2", "0.1",
                "--F3", "0.1",
                "--incE", "100",
                "-E", "100",
                "--domE", "100",
                "--incdomE", "100",
            ]
        self.flags = flags

        if not os.path.exists(self.database_path):
            logging.error("Could not find hmmsearch database %s", database_path)
            raise ValueError(f"Could not find hmmsearch database {database_path}")

    @property
    def output_format(self) -> str:
        return "sto"

    @property
    def input_format(self) -> str:
        return "sto"

    def query(self, msa_sto: str, output_dir: Optional[str] = None) -> str:
        """Queries the database using hmmsearch using a given stockholm msa."""
        hmm = self.hmmbuild_runner.build_profile_from_sto(
            msa_sto, model_construction="hand"
        )
        return self.query_with_hmm(hmm, output_dir)

    def query_with_hmm(self, hmm: str, output_dir: Optional[str] = None) -> str:
        """Queries the database using hmmsearch using a given hmm."""
        with utils.tmpdir_manager() as query_tmp_dir:
            hmm_input_path = os.path.join(query_tmp_dir, "query.hmm")
            output_dir = query_tmp_dir if output_dir is None else output_dir
            out_path = os.path.join(output_dir, "hmm_output.sto")
            with open(hmm_input_path, "w") as f:
                f.write(hmm)

            cmd = [
                self.binary_path,
                "--noali",  # Don't include the alignment in stdout.
                "--cpu", "8",
            ]
            # If adding flags, we have to do so before the output and input:
            if self.flags:
                cmd.extend(self.flags)
            cmd.extend([
                "-A", out_path,
                hmm_input_path,
                self.database_path,
            ])

            logging.info("Launching sub-process %s", cmd)
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            with utils.timing(
                f"hmmsearch ({os.path.basename(self.database_path)}) query"
            ):
                stdout, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                raise RuntimeError(
                    "hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n"
                    % (stdout.decode("utf-8"), stderr.decode("utf-8"))
                )

            with open(out_path) as f:
                out_msa = f.read()

        return out_msa

    @staticmethod
    def get_template_hits(
        output_string: str, input_sequence: str
    ) -> Sequence[parsers.TemplateHit]:
        """Gets parsed template hits from the raw string output by the tool."""
        template_hits = parsers.parse_hmmsearch_sto(
            output_string,
            input_sequence,
        )
        return template_hits
```
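A hedged usage sketch of the new wrapper follows. The binary and database paths are placeholders, the Stockholm MSA file and query sequence are hypothetical, and HMMER's `hmmbuild`/`hmmsearch` must be installed for the subprocess calls to succeed.

```python
from fastfold.data.tools.hmmsearch import Hmmsearch

searcher = Hmmsearch(
    binary_path="/usr/bin/hmmsearch",          # placeholder path
    hmmbuild_binary_path="/usr/bin/hmmbuild",  # placeholder path
    database_path="/data/pdb_seqres.txt",      # placeholder FASTA database
)

with open("query_msa.sto") as f:               # placeholder Stockholm MSA
    msa_sto = f.read()

# query() builds a profile HMM from the MSA via hmmbuild, then searches the
# database; the result is a Stockholm-format alignment of the hits.
sto_output = searcher.query(msa_sto)
hits = Hmmsearch.get_template_hits(sto_output, input_sequence="MKT...")  # placeholder
```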
fastfold/distributed/comm.py (view file @ 6d8b97ec)
...
```diff
@@ -43,7 +43,7 @@ def _gather(tensor: Tensor, dim: int = -1) -> Tensor:
     if gpc.get_world_size(ParallelMode.TENSOR) == 1:
         return tensor
 
-    if dim == 1:
+    if dim == 1 and list(tensor.shape)[0] == 1:
         output_shape = list(tensor.shape)
         output_shape[1] *= gpc.get_world_size(ParallelMode.TENSOR)
         output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
```
...
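The tightened guard only takes the `dim == 1` fast path when the leading dimension is 1. A toy illustration of the output-buffer shape it computes, with a world size of 4 standing in for `gpc.get_world_size(ParallelMode.TENSOR)`:

```python
import torch

world_size = 4  # stand-in for gpc.get_world_size(ParallelMode.TENSOR)
dim = 1
tensor = torch.randn(1, 8, 16)

if dim == 1 and list(tensor.shape)[0] == 1:
    output_shape = list(tensor.shape)
    output_shape[1] *= world_size  # the gathered dim grows by the world size
    output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
    print(output.shape)  # torch.Size([1, 32, 16])
```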
fastfold/model/fastnn/blocks.py (view file @ 6d8b97ec)
...
```diff
@@ -32,7 +32,7 @@ from fastfold.distributed.comm_async import All_to_All_Async, All_to_All_Async_Opp
 class EvoformerBlock(nn.Module):
 
-    def __init__(self, c_m: int, c_z: int, first_block: bool, last_block: bool):
+    def __init__(self, c_m: int, c_z: int, first_block: bool, last_block: bool, is_multimer: bool = False):
 
         super(EvoformerBlock, self).__init__()
 
         self.first_block = first_block
```
...
```diff
@@ -41,6 +41,7 @@ class EvoformerBlock(nn.Module):
         self.msa_stack = MSAStack(c_m, c_z, p_drop=0.15)
         self.communication = OutProductMean(n_feat=c_m, n_feat_out=c_z, n_feat_proj=32)
         self.pair_stack = PairStack(d_pair=c_z)
+        self.is_multimer = is_multimer
 
     def forward(
         self,
```
...
```diff
@@ -73,12 +74,19 @@ class EvoformerBlock(nn.Module):
             msa_mask = torch.nn.functional.pad(msa_mask, (0, padding_size))
             pair_mask = torch.nn.functional.pad(pair_mask, (0, padding_size, 0, padding_size))
 
-        m = self.msa_stack(m, z, msa_mask)
-        z = z + self.communication(m, msa_mask)
-        m, work = All_to_All_Async.apply(m, 1, 2)
-        z = self.pair_stack(z, pair_mask)
-        m = All_to_All_Async_Opp.apply(m, work, 1, 2)
+        if not self.is_multimer:
+            m = self.msa_stack(m, z, msa_mask)
+            z = z + self.communication(m, msa_mask)
+            m, work = All_to_All_Async.apply(m, 1, 2)
+            z = self.pair_stack(z, pair_mask)
+            m = All_to_All_Async_Opp.apply(m, work, 1, 2)
+        else:
+            z = z + self.communication(m, msa_mask)
+            z_ori = z
+            m, work = All_to_All_Async.apply(m, 1, 2)
+            z = self.pair_stack(z, pair_mask)
+            m = All_to_All_Async_Opp.apply(m, work, 1, 2)
+            m = self.msa_stack(m, z_ori, msa_mask)
 
         if self.last_block:
             m = m.squeeze(0)
```
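To make the reordering easier to follow, here is a schematic with stand-in functions, not FastFold's real modules, and with the All_to_All communication collectives omitted: in the multimer path the outer-product-mean communication runs before the MSA stack, and the MSA stack consumes the pair representation saved before the pair stack ran (`z_ori`).

```python
import torch

# Stand-ins for self.msa_stack / self.communication / self.pair_stack.
def msa_stack(m, z, mask): return m + z.mean()
def communication(m, mask): return 0.1 * m.mean()
def pair_stack(z, mask): return 2 * z

m, z, mask = torch.ones(2, 4), torch.ones(4, 4), None

# Monomer order: MSA stack first, then communication, then pair stack.
m_mono = msa_stack(m, z, mask)
z_mono = pair_stack(z + communication(m_mono, mask), mask)

# Multimer order: communication first; the MSA stack reads z_ori,
# the pair representation *before* the pair stack updated it.
z_multi = z + communication(m, mask)
z_ori = z_multi
z_multi = pair_stack(z_multi, mask)
m_multi = msa_stack(m, z_ori, mask)
```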
...
...
```diff
@@ -260,7 +268,6 @@ class TemplatePairStackBlock(nn.Module):
             single_templates[i] = single
-
         z = torch.cat(single_templates, dim=-4)
 
         if self.last_block:
             z = gather(z, dim=1)
             z = z[:, :-padding_size, :-padding_size, :]
```
...
...
fastfold/model/hub/alphafold.py (view file @ 6d8b97ec)
...
```diff
@@ -26,6 +26,7 @@ from fastfold.utils.feats import (
 )
 from fastfold.model.nn.embedders import (
     InputEmbedder,
+    InputEmbedderMultimer,
     RecyclingEmbedder,
     TemplateAngleEmbedder,
     TemplatePairEmbedder,
```
...
...
```diff
@@ -69,9 +70,14 @@ class AlphaFold(nn.Module):
         extra_msa_config = config.extra_msa
 
         # Main trunk + structure module
-        self.input_embedder = InputEmbedder(
-            **config["input_embedder"],
-        )
+        if self.globals.is_multimer:
+            self.input_embedder = InputEmbedderMultimer(
+                **config["input_embedder"],
+            )
+        else:
+            self.input_embedder = InputEmbedder(
+                **config["input_embedder"],
+            )
         self.recycling_embedder = RecyclingEmbedder(
             **config["recycling_embedder"],
         )
```
...
...
fastfold/model/nn/embedders.py (view file @ 6d8b97ec)
...
```diff
@@ -15,7 +15,7 @@
 import torch
 import torch.nn as nn
-from typing import Tuple
+from typing import Tuple, Dict
 
 from fastfold.model.nn.primitives import Linear, LayerNorm
 from fastfold.utils.tensor_utils import one_hot
```
...
...
```diff
@@ -125,6 +125,146 @@ class InputEmbedder(nn.Module):
         return msa_emb, pair_emb
 
 
+class InputEmbedderMultimer(nn.Module):
+    """
+    Embeds a subset of the input features.
+
+    Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
+    """
+
+    def __init__(
+        self,
+        tf_dim: int,
+        msa_dim: int,
+        c_z: int,
+        c_m: int,
+        max_relative_idx: int,
+        use_chain_relative: bool,
+        max_relative_chain: int,
+        **kwargs,
+    ):
+        """
+        Args:
+            tf_dim:
+                Final dimension of the target features
+            msa_dim:
+                Final dimension of the MSA features
+            c_z:
+                Pair embedding dimension
+            c_m:
+                MSA embedding dimension
+            relpos_k:
+                Window size used in relative positional encoding
+        """
+        super(InputEmbedderMultimer, self).__init__()
+
+        self.tf_dim = tf_dim
+        self.msa_dim = msa_dim
+
+        self.c_z = c_z
+        self.c_m = c_m
+
+        self.linear_tf_z_i = Linear(tf_dim, c_z)
+        self.linear_tf_z_j = Linear(tf_dim, c_z)
+        self.linear_tf_m = Linear(tf_dim, c_m)
+        self.linear_msa_m = Linear(msa_dim, c_m)
+
+        # RPE stuff
+        self.max_relative_idx = max_relative_idx
+        self.use_chain_relative = use_chain_relative
+        self.max_relative_chain = max_relative_chain
+        if self.use_chain_relative:
+            self.no_bins = 2 * max_relative_idx + 2 + 1 + 2 * max_relative_chain + 2
+        else:
+            self.no_bins = 2 * max_relative_idx + 1
+        self.linear_relpos = Linear(self.no_bins, c_z)
+
+    def relpos(self, batch: Dict[str, torch.Tensor]):
+        pos = batch["residue_index"]
+        asym_id = batch["asym_id"]
+        asym_id_same = asym_id[..., None] == asym_id[..., None, :]
+        offset = pos[..., None] - pos[..., None, :]
+
+        clipped_offset = torch.clamp(
+            offset + self.max_relative_idx, 0, 2 * self.max_relative_idx
+        )
+
+        rel_feats = []
+        if self.use_chain_relative:
+            final_offset = torch.where(
+                asym_id_same,
+                clipped_offset,
+                (2 * self.max_relative_idx + 1) * torch.ones_like(clipped_offset),
+            )
+
+            rel_pos = torch.nn.functional.one_hot(
+                final_offset,
+                2 * self.max_relative_idx + 2,
+            )
+            rel_feats.append(rel_pos)
+
+            entity_id = batch["entity_id"]
+            entity_id_same = entity_id[..., None] == entity_id[..., None, :]
+            rel_feats.append(entity_id_same[..., None])
+
+            sym_id = batch["sym_id"]
+            rel_sym_id = sym_id[..., None] - sym_id[..., None, :]
+
+            max_rel_chain = self.max_relative_chain
+
+            clipped_rel_chain = torch.clamp(
+                rel_sym_id + max_rel_chain,
+                0,
+                2 * max_rel_chain,
+            )
+
+            final_rel_chain = torch.where(
+                entity_id_same,
+                clipped_rel_chain,
+                (2 * max_rel_chain + 1) * torch.ones_like(clipped_rel_chain),
+            )
+
+            rel_chain = torch.nn.functional.one_hot(
+                final_rel_chain.long(),
+                2 * max_rel_chain + 2,
+            )
+
+            rel_feats.append(rel_chain)
+        else:
+            rel_pos = torch.nn.functional.one_hot(
+                clipped_offset,
+                2 * self.max_relative_idx + 1,
+            )
+            rel_feats.append(rel_pos)
+
+        rel_feat = torch.cat(rel_feats, dim=-1).to(self.linear_relpos.weight.dtype)
+
+        return self.linear_relpos(rel_feat)
+
+    def forward(
+        self, batch: Dict[str, torch.Tensor]
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        tf = batch["target_feat"]
+        msa = batch["msa_feat"]
+
+        # [*, N_res, c_z]
+        tf_emb_i = self.linear_tf_z_i(tf)
+        tf_emb_j = self.linear_tf_z_j(tf)
+
+        # [*, N_res, N_res, c_z]
+        pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
+        pair_emb = pair_emb + self.relpos(batch)
+
+        # [*, N_clust, N_res, c_m]
+        n_clust = msa.shape[-3]
+        tf_m = (
+            self.linear_tf_m(tf)
+            .unsqueeze(-3)
+            .expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
+        )
+        msa_emb = self.linear_msa_m(msa) + tf_m
+
+        return msa_emb, pair_emb
+
+
 class RecyclingEmbedder(nn.Module):
     """
```
...
...
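A small, self-contained check of the relative-position binning in `InputEmbedderMultimer.relpos`, using hypothetical chain ids: with `max_relative_idx=32`, `use_chain_relative=True`, and `max_relative_chain=2` (the values in `multimer_model_config_update`), the concatenated one-hot features span 2·32+2 + 1 + 2·2+2 = 73 bins, matching `no_bins` above.

```python
import torch

max_relative_idx = 32
residue_index = torch.tensor([0, 1, 2, 0, 1])  # hypothetical: 3 + 2 residues
asym_id = torch.tensor([0, 0, 0, 1, 1])        # hypothetical: two chains

offset = residue_index[..., None] - residue_index[..., None, :]
same_chain = asym_id[..., None] == asym_id[..., None, :]

clipped = torch.clamp(offset + max_relative_idx, 0, 2 * max_relative_idx)
# Cross-chain pairs all land in one extra "different chain" bin.
final_offset = torch.where(
    same_chain, clipped, (2 * max_relative_idx + 1) * torch.ones_like(clipped)
)
rel_pos = torch.nn.functional.one_hot(final_offset, 2 * max_relative_idx + 2)
print(rel_pos.shape)  # torch.Size([5, 5, 66]); plus 1 + 6 chain bins -> 73 total
```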
fastfold/model/nn/template.py (view file @ 6d8b97ec)
...
...
```diff
@@ -142,6 +142,7 @@ class TemplatePairStackBlock(nn.Module):
         pair_transition_n: int,
         dropout_rate: float,
         inf: float,
+        is_multimer: bool = False,
         **kwargs,
     ):
         super(TemplatePairStackBlock, self).__init__()
```
...
```diff
@@ -153,6 +154,7 @@ class TemplatePairStackBlock(nn.Module):
         self.pair_transition_n = pair_transition_n
         self.dropout_rate = dropout_rate
         self.inf = inf
+        self.is_multimer = is_multimer
 
         self.dropout_row = DropoutRowwise(self.dropout_rate)
         self.dropout_col = DropoutColumnwise(self.dropout_rate)
```
...
```diff
@@ -196,43 +198,67 @@ class TemplatePairStackBlock(nn.Module):
         single_templates_masks = [
             m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)
         ]
-        for i in range(len(single_templates)):
-            single = single_templates[i]
-            single_mask = single_templates_masks[i]
-
-            single = single + self.dropout_row(
-                self.tri_att_start(
-                    single,
-                    chunk_size=chunk_size,
-                    mask=single_mask
-                )
-            )
-            single = single + self.dropout_col(
-                self.tri_att_end(
-                    single,
-                    chunk_size=chunk_size,
-                    mask=single_mask
-                )
-            )
-            single = single + self.dropout_row(
-                self.tri_mul_out(single, mask=single_mask)
-            )
-            single = single + self.dropout_row(
-                self.tri_mul_in(single, mask=single_mask)
-            )
-            single = single + self.pair_transition(
-                single,
-                mask=single_mask if _mask_trans else None,
-                chunk_size=chunk_size,
-            )
-
-            single_templates[i] = single
+        if not self.is_multimer:
+            for i in range(len(single_templates)):
+                single = single_templates[i]
+                single_mask = single_templates_masks[i]
+
+                single = single + self.dropout_row(
+                    self.tri_att_start(
+                        single,
+                        chunk_size=chunk_size,
+                        mask=single_mask
+                    )
+                )
+                single = single + self.dropout_col(
+                    self.tri_att_end(
+                        single,
+                        chunk_size=chunk_size,
+                        mask=single_mask
+                    )
+                )
+                single = single + self.dropout_row(
+                    self.tri_mul_out(single, mask=single_mask)
+                )
+                single = single + self.dropout_row(
+                    self.tri_mul_in(single, mask=single_mask)
+                )
+                single = single + self.pair_transition(
+                    single,
+                    mask=single_mask if _mask_trans else None,
+                    chunk_size=chunk_size,
+                )
+
+                single_templates[i] = single
+        else:
+            for i in range(len(single_templates)):
+                single = single_templates[i]
+                single_mask = single_templates_masks[i]
+
+                single = single + self.dropout_row(
+                    self.tri_mul_out(single, mask=single_mask)
+                )
+                single = single + self.dropout_row(
+                    self.tri_mul_in(single, mask=single_mask)
+                )
+                single = single + self.dropout_row(
+                    self.tri_att_start(
+                        single,
+                        chunk_size=chunk_size,
+                        mask=single_mask
+                    )
+                )
+                single = single + self.dropout_col(
+                    self.tri_att_end(
+                        single,
+                        chunk_size=chunk_size,
+                        mask=single_mask
+                    )
+                )
+                single = single + self.pair_transition(
+                    single,
+                    mask=single_mask if _mask_trans else None,
+                    chunk_size=chunk_size,
+                )
+
+                single_templates[i] = single
 
         z = torch.cat(single_templates, dim=-4)
```
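As reconstructed above, both branches apply the same five per-template updates; the multimer branch only moves the triangular multiplicative updates ahead of the triangular attention. A compact way to read the difference:

```python
# Per-template update order in each branch (names as in the block above).
monomer_order = ["tri_att_start", "tri_att_end", "tri_mul_out", "tri_mul_in", "pair_transition"]
multimer_order = ["tri_mul_out", "tri_mul_in", "tri_att_start", "tri_att_end", "pair_transition"]
```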
...
...
fastfold/workflow/template/fastfold_data_workflow.py (view file @ 6d8b97ec)
```diff
 import os
+import time
 from multiprocessing import cpu_count
 
 import ray
 from ray import workflow
 
 from fastfold.workflow.factory import JackHmmerFactory, HHSearchFactory, HHBlitsFactory
 from fastfold.workflow import batch_run
```
...
...
```diff
@@ -80,6 +81,7 @@ class FastFoldDataWorkFlow:
                 print("Workflow not found. Clean. Skipping")
                 pass
 
+        # prepare alignment directory for alignment outputs
         if alignment_dir is None:
             alignment_dir = os.path.join(output_dir, "alignment")
```
...
...
inference.py (view file @ 6d8b97ec)
...
```diff
@@ -23,6 +23,7 @@ from datetime import date
 import numpy as np
 import torch
 import torch.multiprocessing as mp
+import pickle
 
 from fastfold.model.hub import AlphaFold
 import fastfold
```
...
```diff
@@ -111,6 +112,7 @@ def inference_model(rank, world_size, result_q, batch, args):
 def main(args):
     config = model_config(args.model_name)
+    global_is_multimer = True if args.model_preset == "multimer" else False
     template_featurizer = templates.TemplateHitFeaturizer(
         mmcif_dir=args.template_mmcif_dir,
```
...
...
```diff
@@ -147,6 +149,7 @@ def main(args):
     seqs, tags = parse_fasta(fasta)
     for tag, seq in zip(tags, seqs):
+        print(f"tag: {tag} seq: {seq}")
         batch = [None]
 
         fasta_path = os.path.join(args.output_dir, "tmp.fasta")
```
...
...
```diff
@@ -155,44 +158,48 @@ def main(args):
         print("Generating features...")
         local_alignment_dir = os.path.join(alignment_dir, tag)
-        if (args.use_precomputed_alignments is None):
-            if not os.path.exists(local_alignment_dir):
-                os.makedirs(local_alignment_dir)
-            if args.enable_workflow:
-                print("Running alignment with ray workflow...")
-                alignment_data_workflow_runner = FastFoldDataWorkFlow(
-                    jackhmmer_binary_path=args.jackhmmer_binary_path,
-                    hhblits_binary_path=args.hhblits_binary_path,
-                    hhsearch_binary_path=args.hhsearch_binary_path,
-                    uniref90_database_path=args.uniref90_database_path,
-                    mgnify_database_path=args.mgnify_database_path,
-                    bfd_database_path=args.bfd_database_path,
-                    uniclust30_database_path=args.uniclust30_database_path,
-                    pdb70_database_path=args.pdb70_database_path,
-                    use_small_bfd=use_small_bfd,
-                    no_cpus=args.cpus,
-                )
-                t = time.perf_counter()
-                alignment_data_workflow_runner.run(fasta_path, output_dir=output_dir_base, alignment_dir=local_alignment_dir)
-                print(f"Alignment data workflow time: {time.perf_counter() - t}")
-            else:
-                alignment_runner = data_pipeline.AlignmentRunner(
-                    jackhmmer_binary_path=args.jackhmmer_binary_path,
-                    hhblits_binary_path=args.hhblits_binary_path,
-                    hhsearch_binary_path=args.hhsearch_binary_path,
-                    uniref90_database_path=args.uniref90_database_path,
-                    mgnify_database_path=args.mgnify_database_path,
-                    bfd_database_path=args.bfd_database_path,
-                    uniclust30_database_path=args.uniclust30_database_path,
-                    pdb70_database_path=args.pdb70_database_path,
-                    use_small_bfd=use_small_bfd,
-                    no_cpus=args.cpus,
-                )
-                alignment_runner.run(fasta_path, local_alignment_dir)
-        feature_dict = data_processor.process_fasta(fasta_path=fasta_path,
-                                                    alignment_dir=local_alignment_dir)
+        # if global_is_multimer:
+        #     print("Multimer")
+        # else:
+        #     if (args.use_precomputed_alignments is None):
+        #         if not os.path.exists(local_alignment_dir):
+        #             os.makedirs(local_alignment_dir)
+        #         if args.enable_workflow:
+        #             print("Running alignment with ray workflow...")
+        #             alignment_data_workflow_runner = FastFoldDataWorkFlow(
+        #                 jackhmmer_binary_path=args.jackhmmer_binary_path,
+        #                 hhblits_binary_path=args.hhblits_binary_path,
+        #                 hhsearch_binary_path=args.hhsearch_binary_path,
+        #                 uniref90_database_path=args.uniref90_database_path,
+        #                 mgnify_database_path=args.mgnify_database_path,
+        #                 bfd_database_path=args.bfd_database_path,
+        #                 uniclust30_database_path=args.uniclust30_database_path,
+        #                 pdb70_database_path=args.pdb70_database_path,
+        #                 use_small_bfd=use_small_bfd,
+        #                 no_cpus=args.cpus,
+        #             )
+        #             t = time.perf_counter()
+        #             alignment_data_workflow_runner.run(fasta_path, output_dir=output_dir_base, alignment_dir=local_alignment_dir)
+        #             print(f"Alignment data workflow time: {time.perf_counter() - t}")
+        #         else:
+        #             alignment_runner = data_pipeline.AlignmentRunner(
+        #                 jackhmmer_binary_path=args.jackhmmer_binary_path,
+        #                 hhblits_binary_path=args.hhblits_binary_path,
+        #                 hhsearch_binary_path=args.hhsearch_binary_path,
+        #                 uniref90_database_path=args.uniref90_database_path,
+        #                 mgnify_database_path=args.mgnify_database_path,
+        #                 bfd_database_path=args.bfd_database_path,
+        #                 uniclust30_database_path=args.uniclust30_database_path,
+        #                 pdb70_database_path=args.pdb70_database_path,
+        #                 use_small_bfd=use_small_bfd,
+        #                 no_cpus=args.cpus,
+        #             )
+        #             alignment_runner.run(fasta_path, local_alignment_dir)
+        # feature_dict = data_processor.process_fasta(fasta_path=fasta_path,
+        #                                             alignment_dir=local_alignment_dir)
+        feature_dict = pickle.load(open("/home/lcmql/data/features_pdb1o5d.pkl", "rb"))
 
         # Remove temporary FASTA file
         os.remove(fasta_path)
```
...
...
```diff
@@ -289,6 +296,14 @@ if __name__ == "__main__":
                         default='full_dbs',
                         choices=('reduced_dbs', 'full_dbs'))
     parser.add_argument('--data_random_seed', type=str, default=None)
+    parser.add_argument(
+        "--model_preset",
+        type=str,
+        default="monomer",
+        choices=["monomer", "multimer"],
+        help="Choose preset model configuration - the monomer model, the monomer model with "
+        "extra ensembling, monomer model with pTM head, or multimer model",
+    )
     add_data_args(parser)
     args = parser.parse_args()
```
...
...
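A quick sketch of how the new flag drives the multimer switch in `main()`: `parse_args` yields `args.model_preset`, from which `global_is_multimer` is derived. The argument list here is trimmed to the one new flag.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--model_preset",
    type=str,
    default="monomer",
    choices=["monomer", "multimer"],
)
args = parser.parse_args(["--model_preset", "multimer"])

# Same derivation as in main() above.
global_is_multimer = True if args.model_preset == "multimer" else False
print(global_is_multimer)  # True
```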