Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
54317fe4
Commit
54317fe4
authored
Sep 30, 2021
by
Gustaf Ahdritz
Browse files
Add pTM config options, spruce up pretrained script, add test script
parent
3fb44548
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
181 additions
and
90 deletions
+181
-90
config.py
config.py
+52
-18
openfold/model/model.py
openfold/model/model.py
+2
-2
run_pretrained_alphafold.py
run_pretrained_alphafold.py
+108
-70
scripts/run_unit_tests.sh
scripts/run_unit_tests.sh
+19
-0
No files found.
config.py
View file @
54317fe4
...
@@ -2,38 +2,72 @@ import copy
...
@@ -2,38 +2,72 @@ import copy
import
ml_collections
as
mlc
import
ml_collections
as
mlc
def model_config(name, train=False):
    """Return an independent copy of the base config tailored to a preset.

    Args:
        name: Preset name, one of ``model_1`` .. ``model_5`` or the same
            names with a ``_ptm`` suffix.
        train: When true, additionally apply the training overrides to the
            shared ``globals`` section (``blocks_per_ckpt = 1`` and
            ``chunk_size = None``).

    Returns:
        A deep copy of the module-level ``config`` with the preset's
        overrides applied; the shared base config is never modified.

    Raises:
        ValueError: If ``name`` is not one of the ten known presets.
    """
    ptm_suffix = "_ptm"
    # Presets that run with the template stack switched off.
    template_free = ("model_3", "model_4", "model_5")
    base_names = ("model_1", "model_2") + template_free
    ptm_names = tuple(base + ptm_suffix for base in base_names)

    # Tuples (not sets) so that any object, hashable or not, simply fails
    # the membership test and ends up in the ValueError below.
    if name not in base_names + ptm_names:
        raise ValueError("Invalid model name")

    c = copy.deepcopy(config)

    uses_ptm = name in ptm_names
    base_name = name[:-len(ptm_suffix)] if uses_ptm else name

    if base_name in template_free:
        c.model.template.enabled = False

    if uses_ptm:
        # pTM presets switch on the TM head and give its loss a weight.
        c.model.heads.tm.enabled = True
        c.model.loss.tm.weight = 0.1

    if train:
        # blocks_per_ckpt sits directly under `globals`, next to
        # chunk_size; there is no `globals.model` section to go through.
        c.globals.blocks_per_ckpt = 1
        c.globals.chunk_size = None

    return c
# Shared hyperparameters. Each one is wrapped in a FieldReference and then
# placed both in the config's "globals" section and in the individual
# sub-configs that use it, so the value can be changed in one place.

# Channel widths; in the model code c_m and c_z size the trailing dimension
# of the recycled MSA ([*, N, C_m]) and pair ([*, N, N, C_z]) tensors.
# NOTE(review): c_t, c_e and c_s are presumably the template, extra-MSA and
# single-representation widths -- confirm against the modules that read them.
c_z = mlc.FieldReference(128, field_type=int)
c_m = mlc.FieldReference(256, field_type=int)
c_t = mlc.FieldReference(64, field_type=int)
c_e = mlc.FieldReference(64, field_type=int)
c_s = mlc.FieldReference(384, field_type=int)
# Defaults below are the non-training values; model_config(train=True)
# overrides blocks_per_ckpt and chunk_size.
blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
chunk_size = mlc.FieldReference(4, field_type=int)
# Used as "no_bins" by the "tm" head config.
aux_distogram_bins = mlc.FieldReference(64, field_type=int)
# Small / large float constants shared through the "globals" section.
eps = mlc.FieldReference(1e-8, field_type=float)
inf = mlc.FieldReference(1e8, field_type=float)
config
=
mlc
.
ConfigDict
({
config
=
mlc
.
ConfigDict
({
"model"
:
{
# Recurring FieldReferences that can be changed globally here
"globals"
:
{
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"chunk_size"
:
chunk_size
,
"c_z"
:
c_z
,
"c_z"
:
c_z
,
"c_m"
:
c_m
,
"c_m"
:
c_m
,
"c_t"
:
c_t
,
"c_t"
:
c_t
,
"c_e"
:
c_e
,
"c_e"
:
c_e
,
"c_s"
:
c_s
,
"c_s"
:
c_s
,
"no_cycles"
:
2
,
#4,
"eps"
:
eps
,
"inf"
:
inf
,
},
"model"
:
{
"no_cycles"
:
4
,
"_mask_trans"
:
False
,
"_mask_trans"
:
False
,
"input_embedder"
:
{
"input_embedder"
:
{
"tf_dim"
:
22
,
"tf_dim"
:
22
,
...
@@ -147,7 +181,7 @@ config = mlc.ConfigDict({
...
@@ -147,7 +181,7 @@ config = mlc.ConfigDict({
"no_qk_points"
:
4
,
"no_qk_points"
:
4
,
"no_v_points"
:
8
,
"no_v_points"
:
8
,
"dropout_rate"
:
0.1
,
"dropout_rate"
:
0.1
,
"no_blocks"
:
2
,
#
8,
"no_blocks"
:
8
,
"no_transition_layers"
:
1
,
"no_transition_layers"
:
1
,
"no_resnet_blocks"
:
2
,
"no_resnet_blocks"
:
2
,
"no_angles"
:
7
,
"no_angles"
:
7
,
...
@@ -168,7 +202,7 @@ config = mlc.ConfigDict({
...
@@ -168,7 +202,7 @@ config = mlc.ConfigDict({
"tm"
:
{
"tm"
:
{
"c_z"
:
c_z
,
"c_z"
:
c_z
,
"no_bins"
:
aux_distogram_bins
,
"no_bins"
:
aux_distogram_bins
,
"enabled"
:
Tru
e
,
"enabled"
:
Fals
e
,
},
},
"masked_msa"
:
{
"masked_msa"
:
{
"c_m"
:
c_m
,
"c_m"
:
c_m
,
...
@@ -245,7 +279,7 @@ config = mlc.ConfigDict({
...
@@ -245,7 +279,7 @@ config = mlc.ConfigDict({
"min_resolution"
:
0.1
,
"min_resolution"
:
0.1
,
"max_resolution"
:
3.0
,
"max_resolution"
:
3.0
,
"eps"
:
eps
,
#1e-8,
"eps"
:
eps
,
#1e-8,
"weight"
:
1.
0
,
"weight"
:
0
.
,
},
},
"eps"
:
eps
,
"eps"
:
eps
,
},
},
...
...
openfold/model/model.py
View file @
54317fe4
...
@@ -202,12 +202,12 @@ class AlphaFold(nn.Module):
...
@@ -202,12 +202,12 @@ class AlphaFold(nn.Module):
if
(
None
in
[
m_1_prev
,
z_prev
,
x_prev
]):
if
(
None
in
[
m_1_prev
,
z_prev
,
x_prev
]):
# [*, N, C_m]
# [*, N, C_m]
m_1_prev
=
m
.
new_zeros
(
m_1_prev
=
m
.
new_zeros
(
(
*
batch_dims
,
n
,
self
.
config
.
c_m
),
(
*
batch_dims
,
n
,
self
.
config
.
input_embedder
.
c_m
),
)
)
# [*, N, N, C_z]
# [*, N, N, C_z]
z_prev
=
z
.
new_zeros
(
z_prev
=
z
.
new_zeros
(
(
*
batch_dims
,
n
,
n
,
self
.
config
.
c_z
),
(
*
batch_dims
,
n
,
n
,
self
.
config
.
input_embedder
.
c_z
),
)
)
# [*, N, 3]
# [*, N, 3]
...
...
run_pretrained_alphafold.py
View file @
54317fe4
...
@@ -13,22 +13,23 @@
...
@@ -13,22 +13,23 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
argparse
import
math
import
pickle
import
os
import
os
# A hack to get OpenMM and PyTorch to peacefully coexist
# A hack to get OpenMM and PyTorch to peacefully coexist
os
.
environ
[
"OPENMM_DEFAULT_PLATFORM"
]
=
"OpenCL"
os
.
environ
[
"OPENMM_DEFAULT_PLATFORM"
]
=
"OpenCL"
import
math
import
pickle
import
time
import
time
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
numpy
as
np
from
config
import
model_config
from
config
import
model_config
from
openfold.model.model
import
AlphaFold
from
openfold.model.model
import
AlphaFold
from
openfold.np
import
residue_constants
,
protein
from
openfold.np
import
residue_constants
,
protein
import
openfold.np.relax.relax
as
relax
import
openfold.np.relax.relax
as
relax
from
openfold.utils.import_weights
import
(
from
openfold.utils.import_weights
import
(
import_jax_weights_
,
import_jax_weights_
,
...
@@ -38,23 +39,23 @@ from openfold.utils.tensor_utils import (
...
@@ -38,23 +39,23 @@ from openfold.utils.tensor_utils import (
tensor_tree_map
,
tensor_tree_map
,
)
)
MODEL_NAME
=
"model_1"
MODEL_DEVICE
=
"cuda:4"
PARAM_PATH
=
"openfold/resources/params/params_model_1.npz"
FEAT_PATH
=
"tests/test_data/sample_feats.pickle"
FEAT_PATH
=
"tests/test_data/sample_feats.pickle"
config
=
model_config
(
MODEL_NAME
)
def
main
(
args
):
model
=
AlphaFold
(
config
.
model
)
config
=
model_config
(
args
.
model_name
)
model
=
model
.
eval
()
model
=
AlphaFold
(
config
.
model
)
import_jax_weights_
(
model
,
PARAM_PATH
)
model
=
model
.
eval
()
model
=
model
.
to
(
MODEL_DEVICE
)
import_jax_weights_
(
model
,
args
.
param_path
)
model
=
model
.
to
(
args
.
device
)
with
open
(
FEAT_PATH
,
"rb"
)
as
f
:
with
open
(
FEAT_PATH
,
"rb"
)
as
f
:
batch
=
pickle
.
load
(
f
)
batch
=
pickle
.
load
(
f
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
batch
=
{
k
:
torch
.
as_tensor
(
v
,
device
=
MODEL_DEVICE
)
for
k
,
v
in
batch
.
items
()}
batch
=
{
k
:
torch
.
as_tensor
(
v
,
device
=
args
.
device
)
for
k
,
v
in
batch
.
items
()
}
longs
=
[
longs
=
[
"aatype"
,
"aatype"
,
...
@@ -69,43 +70,80 @@ with torch.no_grad():
...
@@ -69,43 +70,80 @@ with torch.no_grad():
batch
[
l
]
=
batch
[
l
].
long
()
batch
[
l
]
=
batch
[
l
].
long
()
# Move the recycling dimension to the end
# Move the recycling dimension to the end
move_dim
=
lambda
t
:
t
.
permute
(
*
range
(
len
(
t
.
shape
))[
1
:],
0
)
.
contiguous
()
move_dim
=
lambda
t
:
t
.
permute
(
*
range
(
len
(
t
.
shape
))[
1
:],
0
)
batch
=
tensor_tree_map
(
move_dim
,
batch
)
batch
=
tensor_tree_map
(
move_dim
,
batch
)
make_contig
=
lambda
t
:
t
.
contiguous
()
batch
=
tensor_tree_map
(
make_contig
,
batch
)
t
=
time
.
time
()
t
=
time
.
time
()
out
=
model
(
batch
)
out
=
model
(
batch
)
print
(
f
"Inference time:
{
time
.
time
()
-
t
}
"
)
print
(
f
"Inference time:
{
time
.
time
()
-
t
}
"
)
# Toss out the recycling dimensions --- we don't need them anymore
# Toss out the recycling dimensions --- we don't need them anymore
batch
=
tensor_tree_map
(
lambda
x
:
np
.
array
(
x
[...,
-
1
].
cpu
()),
batch
)
batch
=
tensor_tree_map
(
lambda
x
:
np
.
array
(
x
[...,
-
1
].
cpu
()),
batch
)
out
=
tensor_tree_map
(
lambda
x
:
np
.
array
(
x
.
cpu
()),
out
)
out
=
tensor_tree_map
(
lambda
x
:
np
.
array
(
x
.
cpu
()),
out
)
plddt
=
out
[
"plddt"
]
plddt
=
out
[
"plddt"
]
mean_plddt
=
np
.
mean
(
plddt
)
mean_plddt
=
np
.
mean
(
plddt
)
plddt_b_factors
=
np
.
repeat
(
plddt_b_factors
=
np
.
repeat
(
plddt
[...,
None
],
residue_constants
.
atom_type_num
,
axis
=-
1
plddt
[...,
None
],
residue_constants
.
atom_type_num
,
axis
=-
1
)
)
unrelaxed_protein
=
protein
.
from_prediction
(
unrelaxed_protein
=
protein
.
from_prediction
(
features
=
batch
,
features
=
batch
,
result
=
out
,
result
=
out
,
b_factors
=
plddt_b_factors
b_factors
=
plddt_b_factors
)
)
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
"7"
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
"7"
amber_relaxer
=
relax
.
AmberRelaxation
(
amber_relaxer
=
relax
.
AmberRelaxation
(
**
config
.
relax
**
config
.
relax
)
)
# Relax the prediction.
# Relax the prediction.
t
=
time
.
time
()
t
=
time
.
time
()
relaxed_pdb_str
,
_
,
_
=
amber_relaxer
.
process
(
prot
=
unrelaxed_protein
)
relaxed_pdb_str
,
_
,
_
=
amber_relaxer
.
process
(
prot
=
unrelaxed_protein
)
print
(
f
"Relaxation time:
{
time
.
time
()
-
t
}
"
)
print
(
f
"Relaxation time:
{
time
.
time
()
-
t
}
"
)
# Save the relaxed PDB.
# Save the relaxed PDB.
output_dir
=
'.'
relaxed_output_path
=
os
.
path
.
join
(
relaxed_output_path
=
os
.
path
.
join
(
output_dir
,
f
'relaxed_
{
MODEL_NAME
}
.pdb'
)
args
.
output_dir
,
f
'relaxed_
{
args
.
model_name
}
.pdb'
with
open
(
relaxed_output_path
,
'w'
)
as
f
:
)
with
open
(
relaxed_output_path
,
'w'
)
as
f
:
f
.
write
(
relaxed_pdb_str
)
f
.
write
(
relaxed_pdb_str
)
if __name__ == "__main__":
    # Script entry point: declare the command-line options, fill in the
    # default parameter file for the chosen model, then hand off to main().
    # Each row is (flag, type, default, help text).
    cli_options = (
        (
            "--device", str, "cpu",
            """Name of the device on which to run the model. Any valid torch
device name is accepted (e.g. "cpu", "cuda:0")""",
        ),
        (
            "--model_name", str, "model_1",
            """Name of a model config. Choose one of model_{1-5} or
model_{1-5}_ptm, as defined on the AlphaFold GitHub.""",
        ),
        (
            "--output_dir", str, os.getcwd(),
            """Name of the directory in which to output the prediction""",
        ),
        (
            "--param_path", str, None,
            """Path to model parameters. If None, parameters are selected
automatically according to the model name from
openfold/resources/params""",
        ),
    )

    arg_parser = argparse.ArgumentParser()
    for flag, value_type, fallback, description in cli_options:
        arg_parser.add_argument(
            flag, type=value_type, default=fallback, help=description
        )
    args = arg_parser.parse_args()

    # Without an explicit --param_path, derive the weights file from the
    # model name: openfold/resources/params/params_<model_name>.npz
    if args.param_path is None:
        weights_file = "params_" + args.model_name + ".npz"
        args.param_path = os.path.join(
            "openfold", "resources", "params", weights_file
        )

    main(args)
scripts/run_unit_tests.sh
0 → 100755
View file @
54317fe4
#!/bin/bash
#
# Run the project's unit tests with `python3 -m unittest`.
#
# Usage: run_unit_tests.sh [-v] [unittest arguments...]
#   -v    run unittest in verbose mode
#
# Unknown options are reported on stderr and ignored. The script exits with
# unittest's own exit status so that callers (CI, make, ...) see failures.

# Collected as an array so no whitespace trimming is needed when expanding.
FLAGS=()
while getopts ":v" option; do
    case $option in
        v)
            FLAGS+=("-v")
            ;;
        *)
            # With a leading ':' in the optstring getopts runs in silent
            # mode: $option is just "?" and the offending character is
            # stored in OPTARG.
            echo "Invalid option: -${OPTARG}" >&2
            ;;
    esac
done
# Drop the options parsed above so they are not handed to unittest a second
# time through "$@".
shift $((OPTIND - 1))

python3 -m unittest "${FLAGS[@]}" "$@"
status=$?
if [ "$status" -ne 0 ]; then
    echo -e "\nTest(s) failed. Make sure you've installed all Python dependencies."
fi
exit "$status"
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment