OpenDAS / FastFold · Commits · a37c8b4c (unverified)

Authored Aug 23, 2022 by LuGY; committed by GitHub on Aug 23, 2022.

modify reading data way and add inference test (#50)

Parent: 03a24ef9
Showing 3 changed files with 148 additions and 4 deletions:

- inference.py (+3, -4)
- tests/test_data_utils.py (+62, -0)
- tests/test_inference.py (+83, -0)
inference.py

```diff
@@ -32,6 +32,7 @@ from fastfold.config import model_config
 from fastfold.model.fastnn import set_chunk_size
 from fastfold.data import data_pipeline, feature_pipeline, templates
 from fastfold.utils import inject_fastnn
+from fastfold.data.parsers import parse_fasta
 from fastfold.utils.import_weights import import_jax_weights_
 from fastfold.utils.tensor_utils import tensor_tree_map
@@ -141,10 +142,8 @@ def main(args):
     # Gather input sequences
     with open(args.fasta_path, "r") as fp:
-        lines = [l.strip() for l in fp.readlines()]
-    tags, seqs = lines[::2], lines[1::2]
-    tags = [l[1:] for l in tags]
+        fasta = fp.read()
+    seqs, tags = parse_fasta(fasta)

     for tag, seq in zip(tags, seqs):
         batch = [None]
```
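The net effect of the inference.py change is that the whole FASTA file is read in one pass and handed to parse_fasta, instead of assuming strictly alternating header/sequence lines. The snippet below is a minimal sketch of that kind of parser, not the FastFold implementation; parse_fasta_sketch is a hypothetical stand-in, and the (seqs, tags) return order simply mirrors the call site above. A parser like this also tolerates sequences wrapped across multiple lines, which the old [::2]/[1::2] slicing could not.

```python
# Hypothetical stand-in for fastfold.data.parsers.parse_fasta;
# the real function in the FastFold package may differ in details.
def parse_fasta_sketch(fasta: str):
    """Parse a FASTA string into (sequences, descriptions)."""
    tags, seqs = [], []
    for line in fasta.splitlines():
        line = line.strip()
        if not line:
            continue  # skip blank lines between records
        if line.startswith(">"):
            tags.append(line[1:])  # drop the ">" from the header
            seqs.append("")
        else:
            seqs[-1] += line  # concatenate wrapped sequence lines
    return seqs, tags


if __name__ == "__main__":
    demo = ">seq_a\nMKTAYIAKQR\nQISFVKSHFS\n>seq_b\nGSHMSE\n"
    seqs, tags = parse_fasta_sketch(demo)
    assert tags == ["seq_a", "seq_b"]
    assert seqs == ["MKTAYIAKQRQISFVKSHFS", "GSHMSE"]
```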
tests/test_data_utils.py (new file, mode 100644)

```python
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np


def random_template_feats(n_templ, n, batch_size=None):
    b = []
    if batch_size is not None:
        b.append(batch_size)
    batch = {
        "template_mask": np.random.randint(0, 2, (*b, n_templ)),
        "template_pseudo_beta_mask": np.random.randint(0, 2, (*b, n_templ, n)),
        "template_pseudo_beta": np.random.rand(*b, n_templ, n, 3),
        "template_aatype": np.random.randint(0, 22, (*b, n_templ, n)),
        "template_all_atom_mask": np.random.randint(0, 2, (*b, n_templ, n, 37)),
        "template_all_atom_positions": np.random.rand(*b, n_templ, n, 37, 3) * 10,
        "template_torsion_angles_sin_cos": np.random.rand(*b, n_templ, n, 7, 2),
        "template_alt_torsion_angles_sin_cos": np.random.rand(*b, n_templ, n, 7, 2),
        "template_torsion_angles_mask": np.random.rand(*b, n_templ, n, 7),
    }
    batch = {k: v.astype(np.float32) for k, v in batch.items()}
    batch["template_aatype"] = batch["template_aatype"].astype(np.int64)
    return batch


def random_extra_msa_feats(n_extra, n, batch_size=None):
    b = []
    if batch_size is not None:
        b.append(batch_size)
    batch = {
        "extra_msa": np.random.randint(0, 22, (*b, n_extra, n)).astype(np.int64),
        "extra_has_deletion": np.random.randint(0, 2, (*b, n_extra, n)).astype(np.float32),
        "extra_deletion_value": np.random.rand(*b, n_extra, n).astype(np.float32),
        "extra_msa_mask": np.random.randint(0, 2, (*b, n_extra, n)).astype(np.float32),
    }
    return batch
```
(No newline at end of file.)
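As a quick sanity check on the shapes and dtypes these helpers produce, here is a small usage sketch. The sizes 3, 11, 17 and the batch_size of 2 are illustrative only; the assertions follow directly from the code above, and the import assumes the script is run from the tests/ directory, as test_inference.py does.

```python
# Usage sketch for the random feature generators above.
import numpy as np

from test_data_utils import random_extra_msa_feats, random_template_feats

# Template features: most arrays are float32, aatype is int64.
t_feats = random_template_feats(n_templ=3, n=11)
assert t_feats["template_aatype"].shape == (3, 11)
assert t_feats["template_aatype"].dtype == np.int64
assert t_feats["template_all_atom_positions"].shape == (3, 11, 37, 3)

# Extra-MSA features: indices are int64, masks/values are float32.
e_feats = random_extra_msa_feats(n_extra=17, n=11)
assert e_feats["extra_msa"].shape == (17, 11)
assert e_feats["extra_msa_mask"].dtype == np.float32

# Passing batch_size prepends a leading batch dimension to every array.
t_feats_b = random_template_feats(3, 11, batch_size=2)
assert t_feats_b["template_pseudo_beta"].shape == (2, 3, 11, 3)
```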
tests/test_inference.py (new file, mode 100644)

```python
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time

import torch
import ml_collections as mlc

import fastfold
from fastfold.model.hub import AlphaFold
from fastfold.config import model_config
from fastfold.model.fastnn import set_chunk_size
from fastfold.utils import inject_fastnn
from test_data_utils import random_extra_msa_feats, random_template_feats
from fastfold.data import data_transforms
from fastfold.utils.tensor_utils import tensor_tree_map

consts = mlc.ConfigDict(
    {
        "n_res": 11,
        "n_seq": 13,
        "n_templ": 3,
        "n_extra": 17,
    }
)


def inference():
    fastfold.distributed.init_dap()

    n_seq = consts.n_seq
    n_templ = consts.n_templ
    n_res = consts.n_res
    n_extra_seq = consts.n_extra

    config = model_config('model_1')
    model = AlphaFold(config)
    model = inject_fastnn(model)
    model.eval()
    model.cuda()
    set_chunk_size(model.globals.chunk_size)

    batch = {}
    tf = torch.randint(config.model.input_embedder.tf_dim - 1, size=(n_res,))
    batch["target_feat"] = torch.nn.functional.one_hot(tf, config.model.input_embedder.tf_dim).float()
    batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
    batch["residue_index"] = torch.arange(n_res)
    batch["msa_feat"] = torch.rand((n_seq, n_res, config.model.input_embedder.msa_dim))
    t_feats = random_template_feats(n_templ, n_res)
    batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
    extra_feats = random_extra_msa_feats(n_extra_seq, n_res)
    batch.update({k: torch.tensor(v) for k, v in extra_feats.items()})
    batch["msa_mask"] = torch.randint(low=0, high=2, size=(n_seq, n_res)).float()
    batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float()
    batch.update(data_transforms.make_atom14_masks(batch))
    batch["no_recycling_iters"] = torch.tensor(2.)

    add_recycling_dims = lambda t: (
        t.unsqueeze(-1).expand(*t.shape, config.data.common.max_recycling_iters))
    batch = tensor_tree_map(add_recycling_dims, batch)

    with torch.no_grad():
        batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}

        t = time.perf_counter()
        out = model(batch)
        print(f"Inference time: {time.perf_counter() - t}")


if __name__ == "__main__":
    inference()
    print("Inference Test Passed!")
```
(No newline at end of file.)
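The test drives a FastFold AlphaFold model on randomly generated features: it initializes the distributed setup with fastfold.distributed.init_dap(), moves the model and the feature batch to the GPU, and times a single forward pass under torch.no_grad(). One step worth unpacking is the add_recycling_dims lambda, which gives every feature tensor a trailing recycling dimension. Below is a minimal sketch of that transformation; the feature dimension 49 and max_recycling_iters of 3 are placeholders, not values read from the FastFold config.

```python
# Sketch of the recycling-dimension expansion used in test_inference.py.
import torch

max_recycling_iters = 3  # stand-in for config.data.common.max_recycling_iters


def add_recycling_dims(t: torch.Tensor) -> torch.Tensor:
    # unsqueeze(-1) appends a size-1 trailing axis; expand broadcasts it
    # (without copying memory) to the number of recycling iterations.
    return t.unsqueeze(-1).expand(*t.shape, max_recycling_iters)


seq_mask = torch.ones(11)                # (n_res,)
assert add_recycling_dims(seq_mask).shape == (11, 3)

msa_feat = torch.rand(13, 11, 49)        # (n_seq, n_res, msa_dim), sizes illustrative
assert add_recycling_dims(msa_feat).shape == (13, 11, 49, 3)
```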