OpenDAS / FastFold · Commit 9924e7be
Authored Jul 25, 2022 by Shenggan · parent f44557ed

use torch.multiprocess to launch multi-gpu inference
Showing 3 changed files with 94 additions and 81 deletions (+94 -81):

  README.md                     +2   -1
  fastfold/distributed/core.py  +1   -1
  inference.py                  +91  -79
README.md  ·  View file @ 9924e7be

...
@@ -72,8 +72,9 @@ model = inject_fastnn(model)
 For Dynamic Axial Parallelism, you can refer to `./inference.py`. Here is an example of 2 GPUs parallel inference:
 ```shell
-torchrun --nproc_per_node=2 inference.py target.fasta data/pdb_mmcif/mmcif_files/ \
+python inference.py target.fasta data/pdb_mmcif/mmcif_files/ \
     --output_dir ./ \
+    --gpus 2 \
     --uniref90_database_path data/uniref90/uniref90.fasta \
     --mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
     --pdb70_database_path data/pdb70/pdb70 \
...
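The README change replaces the external `torchrun` launcher with a plain `python` entry point plus a `--gpus` flag: the script now starts its own worker processes instead of relying on the launcher to do it. Below is a minimal, illustrative sketch of that launch style; only the `--gpus` flag is taken from the diff, and the `worker` stub stands in for FastFold's `inference_model`.

```python
# Illustrative sketch (not FastFold code): replacing a torchrun launch with
# torch.multiprocessing.spawn driven by a --gpus flag, as this commit does.
import argparse
import torch.multiprocessing as mp


def worker(rank: int, world_size: int):
    # Stand-in for inference_model(); spawn passes the process index as the first argument.
    print(f"worker {rank} of {world_size} started")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpus", type=int, default=1)
    args = parser.parse_args()
    # One process per GPU; the parent blocks until all workers exit.
    mp.spawn(worker, nprocs=args.gpus, args=(args.gpus,))
```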
fastfold/distributed/core.py  ·  View file @ 9924e7be

...
@@ -34,7 +34,7 @@ def init_dap(tensor_model_parallel_size_=None):
     set_missing_distributed_environ('RANK', 0)
     set_missing_distributed_environ('LOCAL_RANK', 0)
     set_missing_distributed_environ('MASTER_ADDR', "localhost")
-    set_missing_distributed_environ('MASTER_PORT', -1)
+    set_missing_distributed_environ('MASTER_PORT', 18417)

     colossalai.launch_from_torch(config={"parallel": dict(tensor=dict(size=tensor_model_parallel_size_))})
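`set_missing_distributed_environ` is not shown in this hunk; presumably it only fills in a default when the variable is absent, so values exported by a real launcher still take precedence. The sketch below is an assumption about its behavior, not the actual FastFold implementation. The substantive change is that the fallback `MASTER_PORT` becomes a usable port (18417) instead of -1, which matters now that the script spawns its own process group without `torchrun`.

```python
import os


# Assumed behavior of set_missing_distributed_environ (not taken from this commit):
# set a default only when the launcher has not already provided the variable.
def set_missing_distributed_environ(key, value):
    if key not in os.environ:
        os.environ[key] = str(value)


# With the new default, a locally spawned process group can rendezvous directly.
set_missing_distributed_environ("MASTER_ADDR", "localhost")
set_missing_distributed_environ("MASTER_PORT", 18417)
```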
inference.py  ·  View file @ 9924e7be

...
@@ -22,6 +22,7 @@ from datetime import date
 import numpy as np
 import torch
+import torch.multiprocessing as mp
 from fastfold.model.hub import AlphaFold
 import fastfold
...
@@ -73,19 +74,39 @@ def add_data_args(parser: argparse.ArgumentParser):
     parser.add_argument('--release_dates_path', type=str, default=None)


-def main(args):
+def inference_model(rank, world_size, result_q, batch, args):
+    os.environ['RANK'] = str(rank)
+    os.environ['LOCAL_RANK'] = str(rank)
+    os.environ['WORLD_SIZE'] = str(world_size)
     # init distributed for Dynamic Axial Parallelism
     fastfold.distributed.init_dap()
+    torch.cuda.set_device(rank)

     config = model_config(args.model_name)
     model = AlphaFold(config)
     import_jax_weights_(model, args.param_path, version=args.model_name)

     model = inject_fastnn(model)
     model = model.eval()
     # script_preset_(model)
     model = model.cuda()

+    with torch.no_grad():
+        batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
+
+        t = time.perf_counter()
+        out = model(batch)
+        print(f"Inference time: {time.perf_counter() - t}")
+
+        out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
+
+        result_q.put(out)
+
+    torch.distributed.barrier()
+    torch.cuda.synchronize()
+
+
+def main(args):
+    config = model_config(args.model_name)
     template_featurizer = templates.TemplateHitFeaturizer(
         mmcif_dir=args.template_mmcif_dir,
         max_template_date=args.max_template_date,
...
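`torch.multiprocessing.spawn` calls its target as `fn(rank, *args)`, which is why the new `inference_model` takes `rank` first and fills in `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` itself; results go back to the parent through a `Manager` queue, since spawned workers cannot return values directly. A small hedged sketch of that round trip, with placeholder names and a trivial payload instead of the real feature batch:

```python
# Hedged sketch of the spawn + Manager-queue pattern used by inference.py.
# The worker body is a placeholder; names other than the env vars are illustrative.
import os
import torch.multiprocessing as mp


def worker(rank, world_size, result_q, payload):
    # Each process exports the variables torch.distributed / init_dap expect.
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    # ... init the process group, pin the GPU, run the model ...
    result_q.put({"rank": rank, "payload": payload})  # every rank reports a result


if __name__ == "__main__":
    world_size = 2
    manager = mp.Manager()
    result_q = manager.Queue()
    mp.spawn(worker, nprocs=world_size, args=(world_size, result_q, "batch"))
    out = result_q.get()  # the parent consumes one result, as main() does
    print(out)
```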
@@ -124,96 +145,83 @@ def main(args):
     for tag, seq in zip(tags, seqs):
-        batch = [None]
-        if torch.distributed.get_rank() == 0:
-            fasta_path = os.path.join(args.output_dir, "tmp.fasta")
-            with open(fasta_path, "w") as fp:
-                fp.write(f">{tag}\n{seq}")
-
-            print("Generating features...")
-            local_alignment_dir = os.path.join(alignment_dir, tag)
-            if (args.use_precomputed_alignments is None):
-                if not os.path.exists(local_alignment_dir):
-                    os.makedirs(local_alignment_dir)
-
-                alignment_runner = data_pipeline.AlignmentRunner(
-                    jackhmmer_binary_path=args.jackhmmer_binary_path,
-                    hhblits_binary_path=args.hhblits_binary_path,
-                    hhsearch_binary_path=args.hhsearch_binary_path,
-                    uniref90_database_path=args.uniref90_database_path,
-                    mgnify_database_path=args.mgnify_database_path,
-                    bfd_database_path=args.bfd_database_path,
-                    uniclust30_database_path=args.uniclust30_database_path,
-                    pdb70_database_path=args.pdb70_database_path,
-                    use_small_bfd=use_small_bfd,
-                    no_cpus=args.cpus,
-                )
-                alignment_runner.run(fasta_path, local_alignment_dir)
-
-            feature_dict = data_processor.process_fasta(
-                fasta_path=fasta_path, alignment_dir=local_alignment_dir
-            )
-
-            # Remove temporary FASTA file
-            os.remove(fasta_path)
-
-            processed_feature_dict = feature_processor.process_features(
-                feature_dict, mode='predict',
-            )
-
-            batch = [processed_feature_dict]
-
-        torch.distributed.broadcast_object_list(batch, src=0)
-        batch = batch[0]
-
-        print("Executing model...")
-        with torch.no_grad():
-            batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
-
-            t = time.perf_counter()
-            out = model(batch)
-            print(f"Inference time: {time.perf_counter() - t}")
-
-        torch.distributed.barrier()
-
-        if torch.distributed.get_rank() == 0:
-            out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
-
-            # Toss out the recycling dimensions --- we don't need them anymore
-            batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
-
-            plddt = out["plddt"]
-            mean_plddt = np.mean(plddt)
-
-            plddt_b_factors = np.repeat(plddt[..., None], residue_constants.atom_type_num, axis=-1)
-
-            unrelaxed_protein = protein.from_prediction(features=batch, result=out, b_factors=plddt_b_factors)
-
-            # Save the unrelaxed PDB.
-            unrelaxed_output_path = os.path.join(args.output_dir, f'{tag}_{args.model_name}_unrelaxed.pdb')
-            with open(unrelaxed_output_path, 'w') as f:
-                f.write(protein.to_pdb(unrelaxed_protein))
-
-            amber_relaxer = relax.AmberRelaxation(
-                use_gpu=True,
-                **config.relax,
-            )
-
-            # Relax the prediction.
-            t = time.perf_counter()
-            relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
-            print(f"Relaxation time: {time.perf_counter() - t}")
-
-            # Save the relaxed PDB.
-            relaxed_output_path = os.path.join(args.output_dir, f'{tag}_{args.model_name}_relaxed.pdb')
-            with open(relaxed_output_path, 'w') as f:
-                f.write(relaxed_pdb_str)
-
-        torch.distributed.barrier()
+        fasta_path = os.path.join(args.output_dir, "tmp.fasta")
+        with open(fasta_path, "w") as fp:
+            fp.write(f">{tag}\n{seq}")
+
+        print("Generating features...")
+        local_alignment_dir = os.path.join(alignment_dir, tag)
+        if (args.use_precomputed_alignments is None):
+            if not os.path.exists(local_alignment_dir):
+                os.makedirs(local_alignment_dir)
+
+            alignment_runner = data_pipeline.AlignmentRunner(
+                jackhmmer_binary_path=args.jackhmmer_binary_path,
+                hhblits_binary_path=args.hhblits_binary_path,
+                hhsearch_binary_path=args.hhsearch_binary_path,
+                uniref90_database_path=args.uniref90_database_path,
+                mgnify_database_path=args.mgnify_database_path,
+                bfd_database_path=args.bfd_database_path,
+                uniclust30_database_path=args.uniclust30_database_path,
+                pdb70_database_path=args.pdb70_database_path,
+                use_small_bfd=use_small_bfd,
+                no_cpus=args.cpus,
+            )
+            alignment_runner.run(fasta_path, local_alignment_dir)
+
+        feature_dict = data_processor.process_fasta(
+            fasta_path=fasta_path, alignment_dir=local_alignment_dir
+        )
+
+        # Remove temporary FASTA file
+        os.remove(fasta_path)
+
+        processed_feature_dict = feature_processor.process_features(
+            feature_dict, mode='predict',
+        )
+
+        batch = processed_feature_dict
+
+        manager = mp.Manager()
+        result_q = manager.Queue()
+        torch.multiprocessing.spawn(inference_model, nprocs=args.gpus, args=(args.gpus, result_q, batch, args))
+
+        out = result_q.get()
+
+        # Toss out the recycling dimensions --- we don't need them anymore
+        batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
+
+        plddt = out["plddt"]
+        mean_plddt = np.mean(plddt)
+
+        plddt_b_factors = np.repeat(plddt[..., None], residue_constants.atom_type_num, axis=-1)
+
+        unrelaxed_protein = protein.from_prediction(features=batch, result=out, b_factors=plddt_b_factors)
+
+        # Save the unrelaxed PDB.
+        unrelaxed_output_path = os.path.join(args.output_dir, f'{tag}_{args.model_name}_unrelaxed.pdb')
+        with open(unrelaxed_output_path, 'w') as f:
+            f.write(protein.to_pdb(unrelaxed_protein))
+
+        amber_relaxer = relax.AmberRelaxation(
+            use_gpu=True,
+            **config.relax,
+        )
+
+        # Relax the prediction.
+        t = time.perf_counter()
+        relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
+        print(f"Relaxation time: {time.perf_counter() - t}")
+
+        # Save the relaxed PDB.
+        relaxed_output_path = os.path.join(args.output_dir, f'{tag}_{args.model_name}_relaxed.pdb')
+        with open(relaxed_output_path, 'w') as f:
+            f.write(relaxed_pdb_str)


 if __name__ == "__main__":
...
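In the reworked `main()`, post-processing happens in the parent process: `tensor_tree_map` applies a function to every tensor leaf of the nested feature/output dicts, and the `x[..., -1]` lambda keeps only the last recycling iteration of each feature before `protein.from_prediction` is called. The snippet below illustrates the idea with a local stand-in for `tensor_tree_map` and toy shapes; it is not FastFold's implementation.

```python
import numpy as np
import torch


def tree_map(fn, tree):
    # Local stand-in for FastFold's tensor_tree_map: apply fn to every leaf of a nested dict.
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    return fn(tree)


# Toy feature/output dicts with a trailing recycling dimension of size 3.
batch = {"aatype": torch.zeros(8, 3), "msa_feat": torch.zeros(8, 16, 3)}
out = {"plddt": torch.rand(8)}

batch = tree_map(lambda x: np.array(x[..., -1].cpu()), batch)  # keep last recycling iteration
out = tree_map(lambda x: np.array(x.cpu()), out)               # move outputs to NumPy

print(batch["aatype"].shape, batch["msa_feat"].shape, out["plddt"].shape)  # (8,) (8, 16) (8,)
```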
@@ -252,6 +260,10 @@ if __name__ == "__main__":
         type=int, default=12,
         help="""Number of CPUs with which to run alignment tools"""
     )
+    parser.add_argument(
+        "--gpus", type=int, default=1,
+        help="""Number of GPUs with which to run inference"""
+    )
     parser.add_argument(
         '--preset', type=str, default='full_dbs',
...