Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
02fc4376
Commit
02fc4376
authored
Jan 10, 2022
by
Gustaf Ahdritz
Browse files
Beef up alignment script
parent
53bb9c10
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
168 additions
and
38 deletions
+168
-38
scripts/precompute_alignments.py
scripts/precompute_alignments.py
+168
-38
No files found.
scripts/precompute_alignments.py
View file @
02fc4376
import
argparse
from
functools
import
partial
import
json
import
logging
import
os
import
threading
from
multiprocessing
import
cpu_count
from
shutil
import
copyfile
import
tempfile
import
openfold.data.mmcif_parsing
as
mmcif_parsing
...
...
@@ -10,30 +15,57 @@ from openfold.np import protein, residue_constants
from
utils
import
add_data_args
#python3 scripts/precompute_alignments.py mmcif_dir/ alignment_dir/ data/uniref90/uniref90.fasta data/mgnify/mgy_clusters_2018_12.fa data/pdb70/pdb70 data/pdb_mmcif/mmcif_files/ data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 --bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt --cpus 16 --jackhmmer_binary_path /home/u00u98too4mkqFBu8M357/openfold/lib/conda/envs/openfold_venv/bin/jackhmmer --hhblits_binary_path /home/u00u98too4mkqFBu8M357/openfold/lib/conda/envs/openfold_venv/bin/hhblits --hhsearch_binary_path /home/u00u98too4mkqFBu8M357/openfold/lib/conda/envs/openfold_venv/bin/hhsearch --kalign_binary_path /home/u00u98too4mkqFBu8M357/openfold/lib/conda/envs/openfold_venv/bin/kalign
logging
.
basicConfig
(
level
=
logging
.
DEBU
G
)
logging
.
basicConfig
(
level
=
logging
.
WARNIN
G
)
def
main
(
args
):
# Build the alignment tool runner
alignment_runner
=
AlignmentRunner
(
jackhmmer_binary_path
=
args
.
jackhmmer_binary_path
,
hhblits_binary_path
=
args
.
hhblits_binary_path
,
hhsearch_binary_path
=
args
.
hhsearch_binary_path
,
uniref90_database_path
=
args
.
uniref90_database_path
,
mgnify_database_path
=
args
.
mgnify_database_path
,
bfd_database_path
=
args
.
bfd_database_path
,
uniclust30_database_path
=
args
.
uniclust30_database_path
,
pdb70_database_path
=
args
.
pdb70_database_path
,
use_small_bfd
=
args
.
bfd_database_path
is
None
,
no_cpus
=
args
.
cpus
,
)
def run_seq_group_alignments(seq_groups, alignment_runner, args):
    """Run alignments once per unique sequence and copy them to duplicates.

    For every ``(sequence, names)`` pair, the alignment is computed for the
    first name only, into ``<args.output_dir>/<first name>``. The resulting
    files are then copied into one directory per remaining name.

    Args:
        seq_groups: Iterable of ``(sequence, names)`` pairs. ``names`` is a
            non-empty list of output directory names sharing ``sequence``.
        alignment_runner: Object exposing ``run(fasta_path, output_dir)``.
        args: Namespace providing ``output_dir``.

    Groups whose first directory cannot be created, or whose alignment run
    fails, are logged and skipped; nothing is raised for them.
    """
    # Function-scope import: the module itself only imports ``copyfile``.
    import shutil

    # Snapshot taken once, before any directory is created by this call.
    dirs = set(os.listdir(args.output_dir))
    for seq, names in seq_groups:
        first_name = names[0]
        alignment_dir = os.path.join(args.output_dir, first_name)

        try:
            os.makedirs(alignment_dir)
        except Exception as e:
            logging.warning(
                f"Failed to create directory for {first_name} "
                f"with exception {e}..."
            )
            continue

        fd, fasta_path = tempfile.mkstemp(suffix=".fasta")
        try:
            with os.fdopen(fd, 'w') as fp:
                fp.write(f'>query\n{seq}')
            alignment_runner.run(fasta_path, alignment_dir)
        except Exception:
            logging.warning(
                f"Failed to run alignments for {first_name}. Skipping...",
                exc_info=True,
            )
            # The runner may have written partial results, so os.rmdir would
            # fail here. Remove the whole tree so that a later run does not
            # mistake the leftovers for a finished alignment.
            shutil.rmtree(alignment_dir, ignore_errors=True)
            continue
        finally:
            # The temporary query file is never needed after the run.
            os.remove(fasta_path)

        # The set of produced files is the same for every duplicate.
        alignment_files = os.listdir(alignment_dir)
        for name in names[1:]:
            if name in dirs:
                logging.warning(
                    f'{name} has already been processed. Skipping...'
                )
                continue

            cp_dir = os.path.join(args.output_dir, name)
            os.makedirs(cp_dir)
            for f in alignment_files:
                shutil.copyfile(
                    os.path.join(alignment_dir, f),
                    os.path.join(cp_dir, f),
                )
for
f
in
os
.
listdir
(
args
.
input_dir
):
def
parse_and_align
(
files
,
alignment_runner
,
args
):
for
f
in
files
:
path
=
os
.
path
.
join
(
args
.
input_dir
,
f
)
file_id
=
os
.
path
.
splitext
(
f
)[
0
]
seq
s
=
{}
seq
_group_dict
=
{}
if
(
f
.
endswith
(
'.cif'
)):
with
open
(
path
,
'r'
)
as
fp
:
mmcif_str
=
fp
.
read
()
...
...
@@ -47,9 +79,10 @@ def main(args):
else
:
continue
mmcif
=
mmcif
.
mmcif_object
for
k
,
v
in
mmcif
.
chain_to_seqres
.
items
():
chain_id
=
'_'
.
join
([
file_id
,
k
])
seqs
[
chain_id
]
=
v
for
chain_letter
,
seq
in
mmcif
.
chain_to_seqres
.
items
():
chain_id
=
'_'
.
join
([
file_id
,
chain_letter
])
l
=
seq_group_dict
.
setdefault
(
seq
,
[])
l
.
append
(
chain_id
)
elif
(
f
.
endswith
(
'.fasta'
)
or
f
.
endswith
(
'.fa'
)):
with
open
(
path
,
'r'
)
as
fp
:
fasta_str
=
fp
.
read
()
...
...
@@ -61,7 +94,7 @@ def main(args):
else
:
logging
.
warning
(
msg
)
input_sequence
=
input_seqs
[
0
]
seq
s
[
file_id
]
=
input_sequence
seq
_group_dict
[
input_sequence
]
=
[
file_id
]
elif
(
f
.
endswith
(
'.core'
)):
with
open
(
path
,
'r'
)
as
fp
:
core_str
=
fp
.
read
()
...
...
@@ -71,27 +104,114 @@ def main(args):
residue_constants
.
restypes_with_x
[
aatype
[
i
]]
for
i
in
range
(
len
(
aatype
))
])
seq
s
[
file_id
]
=
seq
seq
_group_dict
[
seq
]
=
[
file_id
]
else
:
continue
for
name
,
seq
in
seqs
.
items
():
alignment_dir
=
os
.
path
.
join
(
args
.
output_dir
,
name
)
if
(
os
.
path
.
isdir
(
alignment_dir
)):
logging
.
info
(
f
'
{
f
}
has already been processed. Skipping...'
)
continue
os
.
makedirs
(
alignment_dir
)
seq_group_tuples
=
[(
k
,
v
)
for
k
,
v
in
seq_group_dict
.
items
()]
run_seq_group_alignments
(
seq_group_tuples
,
alignment_runner
,
args
)
fd
,
fasta_path
=
tempfile
.
mkstemp
(
suffix
=
".fasta"
)
with
os
.
fdopen
(
fd
,
'w'
)
as
fp
:
fp
.
write
(
f
'>query
\n
{
seq
}
'
)
alignment_runner
.
run
(
fasta_path
,
alignment_dir
)
def main(args):
    """Precompute alignments for every input file, spread over worker threads.

    Steps:
      1. Build an ``AlignmentRunner`` from the binary/database paths in
         ``args``.
      2. List ``args.input_dir`` and, if ``args.mmcif_cache`` is given and
         ``args.filter`` is truthy, drop files whose chains all already have
         a directory in ``args.output_dir``.
      3. If the cache entries carry a ``"seqs"`` field, group unfinished
         chains by sequence and dispatch ``run_seq_group_alignments``;
         otherwise dispatch ``parse_and_align`` on the file names.
      4. Start one thread per task slice and wait for all of them.

    Args:
        args: Parsed command-line namespace. Reads the tool/database path
            attributes passed to ``AlignmentRunner`` plus ``cpus_per_task``,
            ``input_dir``, ``output_dir``, ``mmcif_cache``, ``filter`` and
            ``no_tasks``.
    """
    # Build the alignment tool runner
    alignment_runner = AlignmentRunner(
        jackhmmer_binary_path=args.jackhmmer_binary_path,
        hhblits_binary_path=args.hhblits_binary_path,
        hhsearch_binary_path=args.hhsearch_binary_path,
        uniref90_database_path=args.uniref90_database_path,
        mgnify_database_path=args.mgnify_database_path,
        bfd_database_path=args.bfd_database_path,
        uniclust30_database_path=args.uniclust30_database_path,
        pdb70_database_path=args.pdb70_database_path,
        use_small_bfd=args.bfd_database_path is None,
        no_cpus=args.cpus_per_task,
    )

    files = list(os.listdir(args.input_dir))

    # Do some filtering
    if args.mmcif_cache is not None:
        with open(args.mmcif_cache, "r") as fp:
            cache = json.load(fp)
    else:
        cache = None

    # Names of output directories that already exist. Bound up front so the
    # sequence-grouping branch below can use it even when filtering is off.
    dirs = None
    if cache is not None and args.filter:
        dirs = set(os.listdir(args.output_dir))

        def prot_is_done(f):
            """Return True iff every cached chain of ``f`` has an output dir."""
            prot_id = os.path.splitext(f)[0]
            if prot_id not in cache:
                return False
            return all(
                prot_id + "_" + c in dirs
                for c in cache[prot_id]["chain_ids"]
            )

        files = [f for f in files if not prot_is_done(f)]

    def split_up_arglist(arglist):
        """Slice ``arglist`` for this SLURM node, then into per-task lists."""
        # Split up the survivors
        if os.environ.get("SLURM_JOB_NUM_NODES", 0):
            num_nodes = int(os.environ["SLURM_JOB_NUM_NODES"])
            if num_nodes > 1:
                node_id = int(os.environ["SLURM_NODEID"])
                logging.warning(f"Num nodes: {num_nodes}")
                logging.warning(f"Node ID: {node_id}")
                # Each node takes every num_nodes-th item, offset by its ID.
                arglist = arglist[node_id::num_nodes]

        # Strided slices, one per task.
        return [arglist[i::args.no_tasks] for i in range(args.no_tasks)]

    # ``next`` gets a default so an empty cache falls through to the
    # file-based path instead of raising StopIteration.
    if cache is not None and "seqs" in next(iter(cache.values()), {}):
        if dirs is None:
            dirs = set(os.listdir(args.output_dir))

        seq_group_dict = {}
        for f in files:
            prot_id = os.path.splitext(f)[0]
            if prot_id in cache:
                prot_cache = cache[prot_id]
                chains_seqs = zip(
                    prot_cache["chain_ids"], prot_cache["seqs"]
                )
                for chain, seq in chains_seqs:
                    chain_name = prot_id + "_" + chain
                    if chain_name not in dirs:
                        seq_group_dict.setdefault(seq, []).append(chain_name)

        func = partial(
            run_seq_group_alignments,
            alignment_runner=alignment_runner,
            args=args,
        )

        seq_groups = [(k, v) for k, v in seq_group_dict.items()]

        # Sort them by group length so the tasks are approximately balanced
        seq_groups = sorted(seq_groups, key=lambda x: len(x[1]))

        task_arglist = [[a] for a in split_up_arglist(seq_groups)]
    else:
        func = partial(
            parse_and_align,
            alignment_runner=alignment_runner,
            args=args,
        )
        task_arglist = [[a] for a in split_up_arglist(files)]

    threads = []
    for i, task_args in enumerate(task_arglist):
        print(f"Started thread {i}...")
        t = threading.Thread(target=func, args=task_args)
        threads.append(t)
        t.start()

    for t in threads:
        t.join()
if
__name__
==
"__main__"
:
...
...
@@ -111,9 +231,19 @@ if __name__ == "__main__":
help
=
"Whether to crash on parsing errors"
)
parser
.
add_argument
(
"--cpus"
,
type
=
int
,
default
=
4
,
"--cpus
_per_task
"
,
type
=
int
,
default
=
cpu_count
()
,
help
=
"Number of CPUs to use"
)
parser
.
add_argument
(
"--mmcif_cache"
,
type
=
str
,
default
=
None
,
help
=
"Path to mmCIF cache. Used to filter files to be parsed"
)
parser
.
add_argument
(
"--no_tasks"
,
type
=
int
,
default
=
1
,
)
parser
.
add_argument
(
"--filter"
,
type
=
bool
,
default
=
True
,
)
args
=
parser
.
parse_args
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment