Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastFold
Commits
60d148bd
Commit
60d148bd
authored
Dec 10, 2022
by
zhuww
Browse files
fix some problems and add save .pkl
parent
6a41c3e7
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
19 additions
and
3 deletions
+19
-3
fastfold/model/fastnn/kernel/softmax.py
fastfold/model/fastnn/kernel/softmax.py
+1
-1
fastfold/model/hub/alphafold.py
fastfold/model/hub/alphafold.py
+0
-1
fastfold/model/nn/structure_module.py
fastfold/model/nn/structure_module.py
+0
-1
inference.py
inference.py
+18
-0
No files found.
fastfold/model/fastnn/kernel/softmax.py
View file @
60d148bd
...
...
@@ -4,7 +4,7 @@ import logging
import
torch
_triton_available
=
Tru
e
_triton_available
=
Fals
e
if
_triton_available
:
try
:
from
.triton.softmax
import
softmax_triton_kernel_wrapper
...
...
fastfold/model/hub/alphafold.py
View file @
60d148bd
...
...
@@ -399,7 +399,6 @@ class AlphaFold(nn.Module):
outputs
[
"single"
]
=
s
# Predict 3D structure
z
=
[
z
]
outputs_sm
=
self
.
structure_module
(
s
,
z
,
...
...
fastfold/model/nn/structure_module.py
View file @
60d148bd
...
...
@@ -787,7 +787,6 @@ class StructureModule(nn.Module):
for
i
in
range
(
self
.
no_blocks
):
# [*, N, C_s]
s
=
s
+
self
.
ipa
(
s
,
z
,
rigids
,
mask
)
del
z
s
=
self
.
ipa_dropout
(
s
)
torch
.
cuda
.
empty_cache
()
s
=
self
.
layer_norm_ipa
(
s
)
...
...
inference.py
View file @
60d148bd
...
...
@@ -21,6 +21,7 @@ import time
from
datetime
import
date
import
tempfile
import
contextlib
import
logging
import
numpy
as
np
import
torch
...
...
@@ -43,6 +44,10 @@ from fastfold.data.parsers import parse_fasta
from
fastfold.utils.import_weights
import
import_jax_weights_
from
fastfold.utils.tensor_utils
import
tensor_tree_map
logging
.
basicConfig
()
logger
=
logging
.
getLogger
(
__file__
)
logger
.
setLevel
(
level
=
logging
.
INFO
)
if
int
(
torch
.
__version__
.
split
(
"."
)[
0
])
>=
1
and
int
(
torch
.
__version__
.
split
(
"."
)[
1
])
>
11
:
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
True
...
...
@@ -450,6 +455,15 @@ def inference_monomer_model(args):
# with open(relaxed_output_path, 'w') as f:
# f.write(relaxed_pdb_str)
if
(
args
.
save_outputs
):
output_dict_path
=
os
.
path
.
join
(
args
.
output_dir
,
f
'
{
tag
}
_
{
args
.
model_name
}
_output_dict.pkl'
)
with
open
(
output_dict_path
,
"wb"
)
as
fp
:
pickle
.
dump
(
out
,
fp
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
logger
.
info
(
f
"Model output written to
{
output_dict_path
}
..."
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
...
...
@@ -483,6 +497,10 @@ if __name__ == "__main__":
help
=
"""Path to model parameters. If None, parameters are selected
automatically according to the model name from
./data/params"""
)
parser
.
add_argument
(
"--save_outputs"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to save all model outputs, including embeddings, etc."
)
parser
.
add_argument
(
"--cpus"
,
type
=
int
,
default
=
12
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment