Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastFold
Commits
a8b25e72
Unverified
Commit
a8b25e72
authored
Mar 10, 2023
by
Xuanlei Zhao
Committed by
GitHub
Mar 10, 2023
Browse files
init demo (#161)
parent
c9309ecf
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
153 additions
and
0 deletions
+153
-0
demo.py
demo.py
+153
-0
No files found.
demo.py
0 → 100644
View file @
a8b25e72
# Copyright 2023 HPC-AI Tech Inc.
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
os
import
time
import
fastfold
import
numpy
as
np
import
torch
import
torch.multiprocessing
as
mp
from
fastfold.config
import
model_config
from
fastfold.data
import
data_transforms
from
fastfold.model.fastnn
import
set_chunk_size
from
fastfold.model.hub
import
AlphaFold
from
fastfold.utils.inject_fastnn
import
inject_fastnn
from
fastfold.utils.tensor_utils
import
tensor_tree_map
if
int
(
torch
.
__version__
.
split
(
"."
)[
0
])
>=
1
and
int
(
torch
.
__version__
.
split
(
"."
)[
1
])
>
11
:
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
True
def
random_template_feats
(
n_templ
,
n
):
b
=
[]
batch
=
{
"template_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
*
b
,
n_templ
)),
"template_pseudo_beta_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
*
b
,
n_templ
,
n
)),
"template_pseudo_beta"
:
np
.
random
.
rand
(
*
b
,
n_templ
,
n
,
3
),
"template_aatype"
:
np
.
random
.
randint
(
0
,
22
,
(
*
b
,
n_templ
,
n
)),
"template_all_atom_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
*
b
,
n_templ
,
n
,
37
)),
"template_all_atom_positions"
:
np
.
random
.
rand
(
*
b
,
n_templ
,
n
,
37
,
3
)
*
10
,
"template_torsion_angles_sin_cos"
:
np
.
random
.
rand
(
*
b
,
n_templ
,
n
,
7
,
2
),
"template_alt_torsion_angles_sin_cos"
:
np
.
random
.
rand
(
*
b
,
n_templ
,
n
,
7
,
2
),
"template_torsion_angles_mask"
:
np
.
random
.
rand
(
*
b
,
n_templ
,
n
,
7
),
}
batch
=
{
k
:
v
.
astype
(
np
.
float32
)
for
k
,
v
in
batch
.
items
()}
batch
[
"template_aatype"
]
=
batch
[
"template_aatype"
].
astype
(
np
.
int64
)
return
batch
def
random_extra_msa_feats
(
n_extra
,
n
):
b
=
[]
batch
=
{
"extra_msa"
:
np
.
random
.
randint
(
0
,
22
,
(
*
b
,
n_extra
,
n
)).
astype
(
np
.
int64
),
"extra_has_deletion"
:
np
.
random
.
randint
(
0
,
2
,
(
*
b
,
n_extra
,
n
)).
astype
(
np
.
float32
),
"extra_deletion_value"
:
np
.
random
.
rand
(
*
b
,
n_extra
,
n
).
astype
(
np
.
float32
),
"extra_msa_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
*
b
,
n_extra
,
n
)).
astype
(
np
.
float32
),
}
return
batch
def
generate_batch
(
n_res
):
batch
=
{}
tf
=
torch
.
randint
(
21
,
size
=
(
n_res
,))
batch
[
"target_feat"
]
=
torch
.
nn
.
functional
.
one_hot
(
tf
,
22
).
float
()
batch
[
"aatype"
]
=
torch
.
argmax
(
batch
[
"target_feat"
],
dim
=-
1
)
batch
[
"residue_index"
]
=
torch
.
arange
(
n_res
)
batch
[
"msa_feat"
]
=
torch
.
rand
((
128
,
n_res
,
49
))
t_feats
=
random_template_feats
(
4
,
n_res
)
batch
.
update
({
k
:
torch
.
tensor
(
v
)
for
k
,
v
in
t_feats
.
items
()})
extra_feats
=
random_extra_msa_feats
(
5120
,
n_res
)
batch
.
update
({
k
:
torch
.
tensor
(
v
)
for
k
,
v
in
extra_feats
.
items
()})
batch
[
"msa_mask"
]
=
torch
.
randint
(
low
=
0
,
high
=
2
,
size
=
(
128
,
n_res
)).
float
()
batch
[
"seq_mask"
]
=
torch
.
randint
(
low
=
0
,
high
=
2
,
size
=
(
n_res
,)).
float
()
batch
.
update
(
data_transforms
.
make_atom14_masks
(
batch
))
batch
[
"no_recycling_iters"
]
=
torch
.
tensor
(
2.
)
add_recycling_dims
=
lambda
t
:
(
t
.
unsqueeze
(
-
1
).
expand
(
*
t
.
shape
,
3
))
batch
=
tensor_tree_map
(
add_recycling_dims
,
batch
)
return
batch
def
inference_model
(
rank
,
world_size
,
result_q
,
batch
,
args
):
os
.
environ
[
'RANK'
]
=
str
(
rank
)
os
.
environ
[
'LOCAL_RANK'
]
=
str
(
rank
)
os
.
environ
[
'WORLD_SIZE'
]
=
str
(
world_size
)
# init distributed for Dynamic Axial Parallelism
fastfold
.
distributed
.
init_dap
()
torch
.
cuda
.
set_device
(
rank
)
config
=
model_config
(
args
.
model_name
)
if
args
.
chunk_size
:
config
.
globals
.
chunk_size
=
args
.
chunk_size
config
.
globals
.
inplace
=
args
.
inplace
config
.
globals
.
is_multimer
=
False
model
=
AlphaFold
(
config
)
model
=
inject_fastnn
(
model
)
model
=
model
.
eval
()
model
=
model
.
cuda
()
set_chunk_size
(
model
.
globals
.
chunk_size
)
with
torch
.
no_grad
():
batch
=
{
k
:
torch
.
as_tensor
(
v
).
cuda
()
for
k
,
v
in
batch
.
items
()}
t
=
time
.
perf_counter
()
out
=
model
(
batch
)
print
(
f
"Inference time:
{
time
.
perf_counter
()
-
t
}
"
)
out
=
tensor_tree_map
(
lambda
x
:
np
.
array
(
x
.
cpu
()),
out
)
result_q
.
put
(
out
)
torch
.
distributed
.
barrier
()
torch
.
cuda
.
synchronize
()
def
inference_monomer_model
(
args
):
batch
=
generate_batch
(
args
.
n_res
)
manager
=
mp
.
Manager
()
result_q
=
manager
.
Queue
()
torch
.
multiprocessing
.
spawn
(
inference_model
,
nprocs
=
args
.
gpus
,
args
=
(
args
.
gpus
,
result_q
,
batch
,
args
))
out
=
result_q
.
get
()
# get unrelexed pdb and save
# batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
# plddt = out["plddt"]
# plddt_b_factors = np.repeat(plddt[..., None], residue_constants.atom_type_num, axis=-1)
# unrelaxed_protein = protein.from_prediction(features=batch,
# result=out,
# b_factors=plddt_b_factors)
# with open('demo_unrelex.pdb', 'w+') as fp:
# fp.write(unrelaxed_protein)
def
main
(
args
):
inference_monomer_model
(
args
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--gpus"
,
type
=
int
,
default
=
1
,
help
=
"""Number of GPUs with which to run inference"""
)
parser
.
add_argument
(
"--n_res"
,
type
=
int
,
default
=
50
,
help
=
"virtual residue number of random data"
)
parser
.
add_argument
(
"--model_name"
,
type
=
str
,
default
=
"model_1"
,
help
=
"model name of alphafold"
)
parser
.
add_argument
(
'--chunk_size'
,
type
=
int
,
default
=
None
)
parser
.
add_argument
(
'--inplace'
,
default
=
False
,
action
=
'store_true'
)
args
=
parser
.
parse_args
()
main
(
args
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment