Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
39a6d0e6
Commit
39a6d0e6
authored
Apr 09, 2023
by
Christina Floristean
Browse files
Merging in main branch
parents
d8ee9c5f
84659c93
Changes
101
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1711 additions
and
603 deletions
+1711
-603
openfold/utils/all_atom_multimer.py
openfold/utils/all_atom_multimer.py
+2
-0
openfold/utils/callbacks.py
openfold/utils/callbacks.py
+2
-2
openfold/utils/checkpointing.py
openfold/utils/checkpointing.py
+12
-4
openfold/utils/chunk_utils.py
openfold/utils/chunk_utils.py
+428
-0
openfold/utils/exponential_moving_average.py
openfold/utils/exponential_moving_average.py
+2
-1
openfold/utils/feats.py
openfold/utils/feats.py
+8
-4
openfold/utils/geometry/quat_rigid.py
openfold/utils/geometry/quat_rigid.py
+1
-1
openfold/utils/import_weights.py
openfold/utils/import_weights.py
+88
-92
openfold/utils/kernel/csrc/softmax_cuda_stub.cpp
openfold/utils/kernel/csrc/softmax_cuda_stub.cpp
+36
-0
openfold/utils/logger.py
openfold/utils/logger.py
+2
-3
openfold/utils/loss.py
openfold/utils/loss.py
+18
-11
openfold/utils/precision_utils.py
openfold/utils/precision_utils.py
+23
-0
openfold/utils/rigid_utils.py
openfold/utils/rigid_utils.py
+41
-49
openfold/utils/script_utils.py
openfold/utils/script_utils.py
+256
-0
openfold/utils/superimposition.py
openfold/utils/superimposition.py
+0
-1
openfold/utils/tensor_utils.py
openfold/utils/tensor_utils.py
+14
-301
openfold/utils/trace_utils.py
openfold/utils/trace_utils.py
+422
-0
run_pretrained_openfold.py
run_pretrained_openfold.py
+272
-134
scripts/alignment_db_scripts/create_alignment_db.py
scripts/alignment_db_scripts/create_alignment_db.py
+47
-0
scripts/alignment_db_scripts/unify_alignment_db_indices.py
scripts/alignment_db_scripts/unify_alignment_db_indices.py
+37
-0
No files found.
openfold/utils/all_atom_multimer.py
View file @
39a6d0e6
...
...
@@ -17,9 +17,11 @@ from functools import partial
from
typing
import
Dict
,
Text
,
Tuple
import
torch
import
jax.numpy
as
jnp
from
openfold.np
import
residue_constants
as
rc
from
openfold.utils
import
geometry
,
tensor_utils
from
openfold.utils.geometry.rigid_matrix_vector
import
Rigid3Array
import
numpy
as
np
...
...
openfold/utils/callbacks.py
View file @
39a6d0e6
...
...
@@ -6,8 +6,8 @@ class EarlyStoppingVerbose(EarlyStopping):
The default EarlyStopping callback's verbose mode is too verbose.
This class outputs a message only when it's getting ready to stop.
"""
def
_evalute_stopping_criteria
(
self
,
*
args
):
should_stop
,
reason
=
super
().
_evalute_stopping_criteria
(
*
args
)
def
_evalute_stopping_criteria
(
self
,
*
args
,
**
kwargs
):
should_stop
,
reason
=
super
().
_evalute_stopping_criteria
(
*
args
,
**
kwargs
)
if
(
should_stop
):
rank_zero_info
(
f
"
{
reason
}
\n
"
)
...
...
openfold/utils/checkpointing.py
View file @
39a6d0e6
...
...
@@ -11,11 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
importlib
from
typing
import
Any
,
Tuple
,
List
,
Callable
,
Optional
deepspeed_is_installed
=
importlib
.
util
.
find_spec
(
"deepspeed"
)
is
not
None
if
(
deepspeed_is_installed
):
import
deepspeed
import
deepspeed
import
torch
import
torch.utils.checkpoint
from
typing
import
Any
,
Tuple
,
List
,
Callable
,
Optional
BLOCK_ARG
=
Any
...
...
@@ -23,7 +27,11 @@ BLOCK_ARGS = List[BLOCK_ARG]
def
get_checkpoint_fn
():
if
(
deepspeed
.
checkpointing
.
is_configured
()):
deepspeed_is_configured
=
(
deepspeed_is_installed
and
deepspeed
.
checkpointing
.
is_configured
()
)
if
(
deepspeed_is_configured
):
checkpoint
=
deepspeed
.
checkpointing
.
checkpoint
else
:
checkpoint
=
torch
.
utils
.
checkpoint
.
checkpoint
...
...
@@ -73,7 +81,7 @@ def checkpoint_blocks(
# Avoids mishaps when the blocks take just one argument
args
=
wrap
(
args
)
if
blocks_per_ckpt
is
None
:
if
blocks_per_ckpt
is
None
or
not
torch
.
is_grad_enabled
()
:
return
exec
(
blocks
,
args
)
elif
blocks_per_ckpt
<
1
or
blocks_per_ckpt
>
len
(
blocks
):
raise
ValueError
(
"blocks_per_ckpt must be between 1 and len(blocks)"
)
...
...
openfold/utils/chunk_utils.py
0 → 100644
View file @
39a6d0e6
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
functools
import
partial
import
logging
import
math
from
typing
import
Tuple
,
List
,
Callable
,
Any
,
Dict
,
Sequence
,
Optional
import
torch
from
openfold.utils.tensor_utils
import
(
tree_map
,
tensor_tree_map
,
)
def
_fetch_dims
(
tree
):
shapes
=
[]
tree_type
=
type
(
tree
)
if
tree_type
is
dict
:
for
v
in
tree
.
values
():
shapes
.
extend
(
_fetch_dims
(
v
))
elif
tree_type
is
list
or
tree_type
is
tuple
:
for
t
in
tree
:
shapes
.
extend
(
_fetch_dims
(
t
))
elif
tree_type
is
torch
.
Tensor
:
shapes
.
append
(
tree
.
shape
)
else
:
raise
ValueError
(
"Not supported"
)
return
shapes
@
torch
.
jit
.
ignore
def
_flat_idx_to_idx
(
flat_idx
:
int
,
dims
:
Tuple
[
int
],
)
->
Tuple
[
int
]:
idx
=
[]
for
d
in
reversed
(
dims
):
idx
.
append
(
flat_idx
%
d
)
flat_idx
=
flat_idx
//
d
return
tuple
(
reversed
(
idx
))
@
torch
.
jit
.
ignore
def
_get_minimal_slice_set
(
start
:
Sequence
[
int
],
end
:
Sequence
[
int
],
dims
:
int
,
start_edges
:
Optional
[
Sequence
[
bool
]]
=
None
,
end_edges
:
Optional
[
Sequence
[
bool
]]
=
None
,
)
->
Sequence
[
Tuple
[
int
]]:
"""
Produces an ordered sequence of tensor slices that, when used in
sequence on a tensor with shape dims, yields tensors that contain every
leaf in the contiguous range [start, end]. Care is taken to yield a
short sequence of slices, and perhaps even the shortest possible (I'm
pretty sure it's the latter).
end is INCLUSIVE.
"""
# start_edges and end_edges both indicate whether, starting from any given
# dimension, the start/end index is at the top/bottom edge of the
# corresponding tensor, modeled as a tree
def
reduce_edge_list
(
l
):
tally
=
1
for
i
in
range
(
len
(
l
)):
reversed_idx
=
-
1
*
(
i
+
1
)
l
[
reversed_idx
]
*=
tally
tally
=
l
[
reversed_idx
]
if
(
start_edges
is
None
):
start_edges
=
[
s
==
0
for
s
in
start
]
reduce_edge_list
(
start_edges
)
if
(
end_edges
is
None
):
end_edges
=
[
e
==
(
d
-
1
)
for
e
,
d
in
zip
(
end
,
dims
)]
reduce_edge_list
(
end_edges
)
# Base cases. Either start/end are empty and we're done, or the final,
# one-dimensional tensor can be simply sliced
if
(
len
(
start
)
==
0
):
return
[
tuple
()]
elif
(
len
(
start
)
==
1
):
return
[(
slice
(
start
[
0
],
end
[
0
]
+
1
),)]
slices
=
[]
path
=
[]
# Dimensions common to start and end can be selected directly
for
s
,
e
in
zip
(
start
,
end
):
if
(
s
==
e
):
path
.
append
(
slice
(
s
,
s
+
1
))
else
:
break
path
=
tuple
(
path
)
divergence_idx
=
len
(
path
)
# start == end, and we're done
if
(
divergence_idx
==
len
(
dims
)):
return
[
tuple
(
path
)]
def
upper
():
sdi
=
start
[
divergence_idx
]
return
[
path
+
(
slice
(
sdi
,
sdi
+
1
),)
+
s
for
s
in
_get_minimal_slice_set
(
start
[
divergence_idx
+
1
:],
[
d
-
1
for
d
in
dims
[
divergence_idx
+
1
:]],
dims
[
divergence_idx
+
1
:],
start_edges
=
start_edges
[
divergence_idx
+
1
:],
end_edges
=
[
1
for
_
in
end_edges
[
divergence_idx
+
1
:]]
)
]
def
lower
():
edi
=
end
[
divergence_idx
]
return
[
path
+
(
slice
(
edi
,
edi
+
1
),)
+
s
for
s
in
_get_minimal_slice_set
(
[
0
for
_
in
start
[
divergence_idx
+
1
:]],
end
[
divergence_idx
+
1
:],
dims
[
divergence_idx
+
1
:],
start_edges
=
[
1
for
_
in
start_edges
[
divergence_idx
+
1
:]],
end_edges
=
end_edges
[
divergence_idx
+
1
:],
)
]
# If both start and end are at the edges of the subtree rooted at
# divergence_idx, we can just select the whole subtree at once
if
(
start_edges
[
divergence_idx
]
and
end_edges
[
divergence_idx
]):
slices
.
append
(
path
+
(
slice
(
start
[
divergence_idx
],
end
[
divergence_idx
]
+
1
),)
)
# If just start is at the edge, we can grab almost all of the subtree,
# treating only the ragged bottom edge as an edge case
elif
(
start_edges
[
divergence_idx
]):
slices
.
append
(
path
+
(
slice
(
start
[
divergence_idx
],
end
[
divergence_idx
]),)
)
slices
.
extend
(
lower
())
# Analogous to the previous case, but the top is ragged this time
elif
(
end_edges
[
divergence_idx
]):
slices
.
extend
(
upper
())
slices
.
append
(
path
+
(
slice
(
start
[
divergence_idx
]
+
1
,
end
[
divergence_idx
]
+
1
),)
)
# If both sides of the range are ragged, we need to handle both sides
# separately. If there's contiguous meat in between them, we can index it
# in one big chunk
else
:
slices
.
extend
(
upper
())
middle_ground
=
end
[
divergence_idx
]
-
start
[
divergence_idx
]
if
(
middle_ground
>
1
):
slices
.
append
(
path
+
(
slice
(
start
[
divergence_idx
]
+
1
,
end
[
divergence_idx
]),)
)
slices
.
extend
(
lower
())
return
[
tuple
(
s
)
for
s
in
slices
]
@
torch
.
jit
.
ignore
def
_chunk_slice
(
t
:
torch
.
Tensor
,
flat_start
:
int
,
flat_end
:
int
,
no_batch_dims
:
int
,
)
->
torch
.
Tensor
:
"""
Equivalent to
t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]
but without the need for the initial reshape call, which can be
memory-intensive in certain situations. The only reshape operations
in this function are performed on sub-tensors that scale with
(flat_end - flat_start), the chunk size.
"""
batch_dims
=
t
.
shape
[:
no_batch_dims
]
start_idx
=
list
(
_flat_idx_to_idx
(
flat_start
,
batch_dims
))
# _get_minimal_slice_set is inclusive
end_idx
=
list
(
_flat_idx_to_idx
(
flat_end
-
1
,
batch_dims
))
# Get an ordered list of slices to perform
slices
=
_get_minimal_slice_set
(
start_idx
,
end_idx
,
batch_dims
,
)
sliced_tensors
=
[
t
[
s
]
for
s
in
slices
]
return
torch
.
cat
(
[
s
.
view
((
-
1
,)
+
t
.
shape
[
no_batch_dims
:])
for
s
in
sliced_tensors
]
)
def
chunk_layer
(
layer
:
Callable
,
inputs
:
Dict
[
str
,
Any
],
chunk_size
:
int
,
no_batch_dims
:
int
,
low_mem
:
bool
=
False
,
_out
:
Any
=
None
,
_add_into_out
:
bool
=
False
,
)
->
Any
:
"""
Implements the "chunking" procedure described in section 1.11.8.
Layer outputs and inputs are assumed to be simple "pytrees,"
consisting only of (arbitrarily nested) lists, tuples, and dicts with
torch.Tensor leaves.
Args:
layer:
The layer to be applied chunk-wise
inputs:
A (non-nested) dictionary of keyworded inputs. All leaves must
be tensors and must share the same batch dimensions.
chunk_size:
The number of sub-batches per chunk. If multiple batch
dimensions are specified, a "sub-batch" is defined as a single
indexing of all batch dimensions simultaneously (s.t. the
number of sub-batches is the product of the batch dimensions).
no_batch_dims:
How many of the initial dimensions of each input tensor can
be considered batch dimensions.
low_mem:
Avoids flattening potentially large input tensors. Unnecessary
in most cases, and is ever so slightly slower than the default
setting.
Returns:
The reassembled output of the layer on the inputs.
"""
if
not
(
len
(
inputs
)
>
0
):
raise
ValueError
(
"Must provide at least one input"
)
initial_dims
=
[
shape
[:
no_batch_dims
]
for
shape
in
_fetch_dims
(
inputs
)]
orig_batch_dims
=
tuple
([
max
(
s
)
for
s
in
zip
(
*
initial_dims
)])
def
_prep_inputs
(
t
):
if
(
not
low_mem
):
if
not
sum
(
t
.
shape
[:
no_batch_dims
])
==
no_batch_dims
:
t
=
t
.
expand
(
orig_batch_dims
+
t
.
shape
[
no_batch_dims
:])
t
=
t
.
reshape
(
-
1
,
*
t
.
shape
[
no_batch_dims
:])
else
:
t
=
t
.
expand
(
orig_batch_dims
+
t
.
shape
[
no_batch_dims
:])
return
t
prepped_inputs
=
tensor_tree_map
(
_prep_inputs
,
inputs
)
prepped_outputs
=
None
if
(
_out
is
not
None
):
reshape_fn
=
lambda
t
:
t
.
view
([
-
1
]
+
list
(
t
.
shape
[
no_batch_dims
:]))
prepped_outputs
=
tensor_tree_map
(
reshape_fn
,
_out
)
flat_batch_dim
=
1
for
d
in
orig_batch_dims
:
flat_batch_dim
*=
d
no_chunks
=
flat_batch_dim
//
chunk_size
+
(
flat_batch_dim
%
chunk_size
!=
0
)
i
=
0
out
=
prepped_outputs
for
_
in
range
(
no_chunks
):
# Chunk the input
if
(
not
low_mem
):
select_chunk
=
(
lambda
t
:
t
[
i
:
i
+
chunk_size
]
if
t
.
shape
[
0
]
!=
1
else
t
)
else
:
select_chunk
=
(
partial
(
_chunk_slice
,
flat_start
=
i
,
flat_end
=
min
(
flat_batch_dim
,
i
+
chunk_size
),
no_batch_dims
=
len
(
orig_batch_dims
)
)
)
chunks
=
tensor_tree_map
(
select_chunk
,
prepped_inputs
)
# Run the layer on the chunk
output_chunk
=
layer
(
**
chunks
)
# Allocate space for the output
if
out
is
None
:
allocate
=
lambda
t
:
t
.
new_zeros
((
flat_batch_dim
,)
+
t
.
shape
[
1
:])
out
=
tensor_tree_map
(
allocate
,
output_chunk
)
# Put the chunk in its pre-allocated space
out_type
=
type
(
output_chunk
)
if
out_type
is
dict
:
def
assign
(
d1
,
d2
):
for
k
,
v
in
d1
.
items
():
if
type
(
v
)
is
dict
:
assign
(
v
,
d2
[
k
])
else
:
if
(
_add_into_out
):
v
[
i
:
i
+
chunk_size
]
+=
d2
[
k
]
else
:
v
[
i
:
i
+
chunk_size
]
=
d2
[
k
]
assign
(
out
,
output_chunk
)
elif
out_type
is
tuple
:
for
x1
,
x2
in
zip
(
out
,
output_chunk
):
if
(
_add_into_out
):
x1
[
i
:
i
+
chunk_size
]
+=
x2
else
:
x1
[
i
:
i
+
chunk_size
]
=
x2
elif
out_type
is
torch
.
Tensor
:
if
(
_add_into_out
):
out
[
i
:
i
+
chunk_size
]
+=
output_chunk
else
:
out
[
i
:
i
+
chunk_size
]
=
output_chunk
else
:
raise
ValueError
(
"Not supported"
)
i
+=
chunk_size
reshape
=
lambda
t
:
t
.
view
(
orig_batch_dims
+
t
.
shape
[
1
:])
out
=
tensor_tree_map
(
reshape
,
out
)
return
out
class
ChunkSizeTuner
:
def
__init__
(
self
,
# Heuristically, runtimes for most of the modules in the network
# plateau earlier than this on all GPUs I've run the model on.
max_chunk_size
=
512
,
):
self
.
max_chunk_size
=
max_chunk_size
self
.
cached_chunk_size
=
None
self
.
cached_arg_data
=
None
def
_determine_favorable_chunk_size
(
self
,
fn
,
args
,
min_chunk_size
):
logging
.
info
(
"Tuning chunk size..."
)
if
(
min_chunk_size
>=
self
.
max_chunk_size
):
return
min_chunk_size
candidates
=
[
2
**
l
for
l
in
range
(
int
(
math
.
log
(
self
.
max_chunk_size
,
2
))
+
1
)]
candidates
=
[
c
for
c
in
candidates
if
c
>
min_chunk_size
]
candidates
=
[
min_chunk_size
]
+
candidates
candidates
[
-
1
]
+=
4
def
test_chunk_size
(
chunk_size
):
try
:
with
torch
.
no_grad
():
fn
(
*
args
,
chunk_size
=
chunk_size
)
return
True
except
RuntimeError
:
return
False
min_viable_chunk_size_index
=
0
i
=
len
(
candidates
)
-
1
while
i
>
min_viable_chunk_size_index
:
viable
=
test_chunk_size
(
candidates
[
i
])
if
(
not
viable
):
i
=
(
min_viable_chunk_size_index
+
i
)
//
2
else
:
min_viable_chunk_size_index
=
i
i
=
(
i
+
len
(
candidates
)
-
1
)
//
2
return
candidates
[
min_viable_chunk_size_index
]
def
_compare_arg_caches
(
self
,
ac1
,
ac2
):
consistent
=
True
for
a1
,
a2
in
zip
(
ac1
,
ac2
):
assert
(
type
(
ac1
)
==
type
(
ac2
))
if
(
type
(
ac1
)
is
list
or
type
(
ac1
)
is
tuple
):
consistent
&=
self
.
_compare_arg_caches
(
a1
,
a2
)
elif
(
type
(
ac1
)
is
dict
):
a1_items
=
[
v
for
_
,
v
in
sorted
(
a1
.
items
(),
key
=
lambda
x
:
x
[
0
])
]
a2_items
=
[
v
for
_
,
v
in
sorted
(
a2
.
items
(),
key
=
lambda
x
:
x
[
0
])
]
consistent
&=
self
.
_compare_arg_caches
(
a1_items
,
a2_items
)
else
:
consistent
&=
a1
==
a2
return
consistent
def
tune_chunk_size
(
self
,
representative_fn
:
Callable
,
args
:
Tuple
[
Any
],
min_chunk_size
:
int
,
)
->
int
:
consistent
=
True
remove_tensors
=
lambda
a
:
a
.
shape
if
type
(
a
)
is
torch
.
Tensor
else
a
arg_data
=
tree_map
(
remove_tensors
,
args
,
object
)
if
(
self
.
cached_arg_data
is
not
None
):
# If args have changed shape/value, we need to re-tune
assert
(
len
(
self
.
cached_arg_data
)
==
len
(
arg_data
))
consistent
=
self
.
_compare_arg_caches
(
self
.
cached_arg_data
,
arg_data
)
else
:
# Otherwise, we can reuse the precomputed value
consistent
=
False
if
(
not
consistent
):
self
.
cached_chunk_size
=
self
.
_determine_favorable_chunk_size
(
representative_fn
,
args
,
min_chunk_size
,
)
self
.
cached_arg_data
=
arg_data
return
self
.
cached_chunk_size
openfold/utils/exponential_moving_average.py
View file @
39a6d0e6
...
...
@@ -58,7 +58,8 @@ class ExponentialMovingAverage:
self
.
_update_state_dict_
(
model
.
state_dict
(),
self
.
params
)
def
load_state_dict
(
self
,
state_dict
:
OrderedDict
)
->
None
:
self
.
params
=
state_dict
[
"params"
]
for
k
in
state_dict
[
"params"
].
keys
():
self
.
params
[
k
]
=
state_dict
[
"params"
][
k
].
clone
()
self
.
decay
=
state_dict
[
"decay"
]
def
state_dict
(
self
)
->
OrderedDict
:
...
...
openfold/utils/feats.py
View file @
39a6d0e6
...
...
@@ -22,7 +22,7 @@ from typing import Dict, Union
from
openfold.np
import
protein
import
openfold.np.residue_constants
as
rc
from
openfold.utils.geometry
import
rigid_matrix_vector
,
rotation_matrix
from
openfold.utils.geometry
import
rigid_matrix_vector
,
rotation_matrix
,
vector
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.tensor_utils
import
(
batched_gather
,
...
...
@@ -188,13 +188,16 @@ def torsion_angles_to_frames(
aatype
:
torch
.
Tensor
,
rrgdf
:
torch
.
Tensor
,
):
rigid_type
=
Rigid
if
isinstance
(
r
,
Rigid
)
else
rigid_matrix_vector
.
Rigid3Array
# [*, N, 8, 4, 4]
default_4x4
=
rrgdf
[
aatype
,
...]
# [*, N, 8] transformations, i.e.
# One [*, N, 8, 3, 3] rotation matrix and
# One [*, N, 8, 3] translation matrix
default_r
=
r
.
from_tensor_4x4
(
default_4x4
)
default_r
=
r
igid_type
.
from_tensor_4x4
(
default_4x4
)
bb_rot
=
alpha
.
new_zeros
((
*
((
1
,)
*
len
(
alpha
.
shape
[:
-
1
])),
2
))
bb_rot
[...,
1
]
=
1
...
...
@@ -221,11 +224,9 @@ def torsion_angles_to_frames(
all_rots
[...,
2
,
1
:]
=
alpha
if
isinstance
(
r
,
Rigid
):
rigid_type
=
Rigid
all_rots
=
Rigid
(
Rotation
(
rot_mats
=
all_rots
),
None
)
all_frames
=
default_r
.
compose
(
all_rots
)
else
:
rigid_type
=
rigid_matrix_vector
.
Rigid3Array
all_rots
=
rotation_matrix
.
Rot3Array
.
from_array
(
all_rots
)
all_frames
=
default_r
.
compose_rotation
(
all_rots
)
...
...
@@ -291,4 +292,7 @@ def frames_and_literature_positions_to_atom14_pos(
pred_positions
=
t_atoms_to_global
.
apply
(
lit_positions
)
pred_positions
=
pred_positions
*
atom_mask
if
isinstance
(
pred_positions
,
vector
.
Vec3Array
):
return
pred_positions
.
to_tensor
()
return
pred_positions
openfold/utils/geometry/quat_rigid.py
View file @
39a6d0e6
...
...
@@ -16,7 +16,7 @@ class QuatRigid(nn.Module):
else
:
rigid_dim
=
6
self
.
linear
=
Linear
(
c_hidden
,
rigid_dim
)
self
.
linear
=
Linear
(
c_hidden
,
rigid_dim
,
init
=
"final"
)
def
forward
(
self
,
activations
:
torch
.
Tensor
)
->
Rigid3Array
:
# NOTE: During training, this needs to be run in higher precision
...
...
openfold/utils/import_weights.py
View file @
39a6d0e6
...
...
@@ -59,14 +59,14 @@ class Param:
stacked
:
bool
=
False
def
_
process_translation
s
_dict
(
d
,
top_layer
=
True
):
def
process_translation_dict
(
d
,
top_layer
=
True
):
flat
=
{}
for
k
,
v
in
d
.
items
():
if
type
(
v
)
==
dict
:
prefix
=
_NPZ_KEY_PREFIX
if
top_layer
else
""
sub_flat
=
{
(
prefix
+
"/"
.
join
([
k
,
k_prime
])):
v_prime
for
k_prime
,
v_prime
in
_
process_translation
s
_dict
(
for
k_prime
,
v_prime
in
process_translation_dict
(
v
,
top_layer
=
False
).
items
()
}
...
...
@@ -129,7 +129,7 @@ def assign(translation_dict, orig_weights):
raise
def
ge
t
_translation_dict
(
model
,
version
,
is_multimer
=
False
):
def
ge
nerate
_translation_dict
(
model
,
version
,
is_multimer
=
False
):
#######################
# Some templates
#######################
...
...
@@ -277,7 +277,7 @@ def get_translation_dict(model, version, is_multimer=False):
},
"v_scalar_projection"
:
{
"weights"
:
LinearWeightMultimer
(
ipa
.
linear_
k
.
weight
,
ipa
.
linear_
v
.
weight
,
),
},
"q_point_projection"
:
PointProjectionParams
(
...
...
@@ -388,11 +388,6 @@ def get_translation_dict(model, version, is_multimer=False):
############################
# translations dict overflow
############################
tps_blocks
=
model
.
template_embedder
.
template_pair_stack
.
blocks
tps_blocks_params
=
stacked
(
[
TemplatePairBlockParams
(
b
)
for
b
in
tps_blocks
]
)
ems_blocks
=
model
.
extra_msa_stack
.
blocks
ems_blocks_params
=
stacked
([
ExtraMSABlockParams
(
b
)
for
b
in
ems_blocks
])
...
...
@@ -416,32 +411,10 @@ def get_translation_dict(model, version, is_multimer=False):
"pair_activiations"
:
LinearParams
(
model
.
input_embedder
.
linear_relpos
),
"template_embedding"
:
{
"single_template_embedding"
:
{
"embedding2d"
:
LinearParams
(
model
.
template_embedder
.
template_pair_embedder
.
linear
),
"template_pair_stack"
:
{
"__layer_stack_no_state"
:
tps_blocks_params
,
},
"output_layer_norm"
:
LayerNormParams
(
model
.
template_embedder
.
template_pair_stack
.
layer_norm
),
},
"attention"
:
AttentionParams
(
model
.
template_embedder
.
template_pointwise_att
.
mha
),
},
"extra_msa_activations"
:
LinearParams
(
model
.
extra_msa_embedder
.
linear
),
"extra_msa_stack"
:
ems_blocks_params
,
"template_single_embedding"
:
LinearParams
(
model
.
template_embedder
.
template_angle_embedder
.
linear_1
),
"template_projection"
:
LinearParams
(
model
.
template_embedder
.
template_angle_embedder
.
linear_2
),
"evoformer_iteration"
:
evo_blocks_params
,
"single_activations"
:
LinearParams
(
model
.
evoformer
.
linear
),
},
...
...
@@ -478,7 +451,6 @@ def get_translation_dict(model, version, is_multimer=False):
},
}
else
:
temp_embedder
=
model
.
template_embedder
translations
=
{
"evoformer"
:
{
"preprocess_1d"
:
LinearParams
(
model
.
input_embedder
.
linear_tf_m
),
...
...
@@ -497,53 +469,6 @@ def get_translation_dict(model, version, is_multimer=False):
model
.
input_embedder
.
linear_relpos
),
},
"template_embedding"
:
{
"single_template_embedding"
:
{
"query_embedding_norm"
:
LayerNormParams
(
temp_embedder
.
template_pair_embedder
.
query_embedding_layer_norm
),
"template_pair_embedding_0"
:
LinearParams
(
temp_embedder
.
template_pair_embedder
.
dgram_linear
),
"template_pair_embedding_1"
:
LinearParamsMultimer
(
temp_embedder
.
template_pair_embedder
.
pseudo_beta_mask_linear
),
"template_pair_embedding_2"
:
LinearParams
(
temp_embedder
.
template_pair_embedder
.
aatype_linear_1
),
"template_pair_embedding_3"
:
LinearParams
(
temp_embedder
.
template_pair_embedder
.
aatype_linear_2
),
"template_pair_embedding_4"
:
LinearParamsMultimer
(
temp_embedder
.
template_pair_embedder
.
x_linear
),
"template_pair_embedding_5"
:
LinearParamsMultimer
(
temp_embedder
.
template_pair_embedder
.
y_linear
),
"template_pair_embedding_6"
:
LinearParamsMultimer
(
temp_embedder
.
template_pair_embedder
.
z_linear
),
"template_pair_embedding_7"
:
LinearParamsMultimer
(
temp_embedder
.
template_pair_embedder
.
backbone_mask_linear
),
"template_pair_embedding_8"
:
LinearParams
(
temp_embedder
.
template_pair_embedder
.
query_embedding_linear
),
"template_embedding_iteration"
:
tps_blocks_params
,
"output_layer_norm"
:
LayerNormParams
(
model
.
template_embedder
.
template_pair_stack
.
layer_norm
),
},
"output_linear"
:
LinearParams
(
temp_embedder
.
linear_t
),
},
"template_projection"
:
LinearParams
(
temp_embedder
.
template_single_embedder
.
template_projector
,
),
"template_single_embedding"
:
LinearParams
(
temp_embedder
.
template_single_embedder
.
template_single_embedder
,
),
"extra_msa_activations"
:
LinearParams
(
model
.
extra_msa_embedder
.
linear
),
...
...
@@ -592,12 +517,88 @@ def get_translation_dict(model, version, is_multimer=False):
"model_4_ptm"
,
"model_5_ptm"
,
]
if
version
in
no_templ
:
evo_dict
=
translations
[
"evoformer"
]
keys
=
list
(
evo_dict
.
keys
())
for
k
in
keys
:
if
"template_"
in
k
:
evo_dict
.
pop
(
k
)
if
version
not
in
no_templ
:
tps_blocks
=
model
.
template_embedder
.
template_pair_stack
.
blocks
tps_blocks_params
=
stacked
(
[
TemplatePairBlockParams
(
b
)
for
b
in
tps_blocks
]
)
if
(
not
is_multimer
):
template_param_dict
=
{
"template_embedding"
:
{
"single_template_embedding"
:
{
"embedding2d"
:
LinearParams
(
model
.
template_embedder
.
template_pair_embedder
.
linear
),
"template_pair_stack"
:
{
"__layer_stack_no_state"
:
tps_blocks_params
,
},
"output_layer_norm"
:
LayerNormParams
(
model
.
template_embedder
.
template_pair_stack
.
layer_norm
),
},
"attention"
:
AttentionParams
(
model
.
template_embedder
.
template_pointwise_att
.
mha
),
},
"template_single_embedding"
:
LinearParams
(
model
.
template_embedder
.
template_angle_embedder
.
linear_1
),
"template_projection"
:
LinearParams
(
model
.
template_embedder
.
template_angle_embedder
.
linear_2
),
}
else
:
temp_embedder
=
model
.
template_embedder
template_param_dict
=
{
"template_embedding"
:
{
"single_template_embedding"
:
{
"query_embedding_norm"
:
LayerNormParams
(
temp_embedder
.
template_pair_embedder
.
query_embedding_layer_norm
),
"template_pair_embedding_0"
:
LinearParams
(
temp_embedder
.
template_pair_embedder
.
dgram_linear
),
"template_pair_embedding_1"
:
LinearParamsMultimer
(
temp_embedder
.
template_pair_embedder
.
pseudo_beta_mask_linear
),
"template_pair_embedding_2"
:
LinearParams
(
temp_embedder
.
template_pair_embedder
.
aatype_linear_1
),
"template_pair_embedding_3"
:
LinearParams
(
temp_embedder
.
template_pair_embedder
.
aatype_linear_2
),
"template_pair_embedding_4"
:
LinearParamsMultimer
(
temp_embedder
.
template_pair_embedder
.
x_linear
),
"template_pair_embedding_5"
:
LinearParamsMultimer
(
temp_embedder
.
template_pair_embedder
.
y_linear
),
"template_pair_embedding_6"
:
LinearParamsMultimer
(
temp_embedder
.
template_pair_embedder
.
z_linear
),
"template_pair_embedding_7"
:
LinearParamsMultimer
(
temp_embedder
.
template_pair_embedder
.
backbone_mask_linear
),
"template_pair_embedding_8"
:
LinearParams
(
temp_embedder
.
template_pair_embedder
.
query_embedding_linear
),
"template_embedding_iteration"
:
tps_blocks_params
,
"output_layer_norm"
:
LayerNormParams
(
model
.
template_embedder
.
template_pair_stack
.
layer_norm
),
},
"output_linear"
:
LinearParams
(
temp_embedder
.
linear_t
),
},
"template_projection"
:
LinearParams
(
temp_embedder
.
template_single_embedder
.
template_projector
,
),
"template_single_embedding"
:
LinearParams
(
temp_embedder
.
template_single_embedder
.
template_single_embedder
,
),
}
translations
[
"evoformer"
].
update
(
template_param_dict
)
if
"_ptm"
in
version
:
translations
[
"predicted_aligned_error_head"
]
=
{
...
...
@@ -609,15 +610,10 @@ def get_translation_dict(model, version, is_multimer=False):
def
import_jax_weights_
(
model
,
npz_path
,
version
=
"model_1"
):
data
=
np
.
load
(
npz_path
)
translations
=
get_translation_dict
(
model
,
version
,
is_multimer
=
(
"multimer"
in
version
)
)
translations
=
generate_translation_dict
(
model
,
version
,
is_multimer
=
(
"multimer"
in
version
))
# Flatten keys and insert missing key prefixes
flat
=
_
process_translation
s
_dict
(
translations
)
flat
=
process_translation_dict
(
translations
)
# Sanity check
keys
=
list
(
data
.
keys
())
...
...
openfold/utils/kernel/csrc/softmax_cuda_stub.cpp
0 → 100644
View file @
39a6d0e6
// Copyright 2021 AlQuraishi Laboratory
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// modified from fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda.cpp
#include <torch/extension.h>
void
attn_softmax_inplace_forward_
(
at
::
Tensor
input
,
long
long
rows
,
int
cols
)
{
throw
std
::
runtime_error
(
"attn_softmax_inplace_forward_ not implemented on CPU"
);
};
void
attn_softmax_inplace_backward_
(
at
::
Tensor
output
,
at
::
Tensor
d_ov
,
at
::
Tensor
values
,
long
long
rows
,
int
cols_output
,
int
cols_values
)
{
throw
std
::
runtime_error
(
"attn_softmax_inplace_backward_ not implemented on CPU"
);
};
\ No newline at end of file
openfold/utils/logger.py
View file @
39a6d0e6
...
...
@@ -11,16 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
operator
import
time
import
dllogger
as
logger
import
numpy
as
np
import
torch.cuda.profiler
as
profiler
from
dllogger
import
JSONStreamBackend
,
StdOutBackend
,
Verbosity
import
numpy
as
np
from
pytorch_lightning
import
Callback
import
torch.cuda.profiler
as
profiler
def
is_main_process
():
...
...
openfold/utils/loss.py
View file @
39a6d0e6
...
...
@@ -43,9 +43,15 @@ def softmax_cross_entropy(logits, labels):
def
sigmoid_cross_entropy
(
logits
,
labels
):
log_p
=
torch
.
log
(
torch
.
sigmoid
(
logits
))
log_not_p
=
torch
.
log
(
torch
.
sigmoid
(
-
logits
))
loss
=
-
labels
*
log_p
-
(
1
-
labels
)
*
log_not_p
logits_dtype
=
logits
.
dtype
logits
=
logits
.
double
()
labels
=
labels
.
double
()
log_p
=
torch
.
nn
.
functional
.
logsigmoid
(
logits
)
# log_p = torch.log(torch.sigmoid(logits))
log_not_p
=
torch
.
nn
.
functional
.
logsigmoid
(
-
1
*
logits
)
# log_not_p = torch.log(torch.sigmoid(-logits))
loss
=
(
-
1.
*
labels
)
*
log_p
-
(
1.
-
labels
)
*
log_not_p
loss
=
loss
.
to
(
dtype
=
logits_dtype
)
return
loss
...
...
@@ -658,10 +664,11 @@ def compute_tm(
denom
=
eps
+
torch
.
sum
(
pair_residue_weights
,
dim
=-
1
,
keepdims
=
True
)
normed_residue_mask
=
pair_residue_weights
/
denom
per_alignment
=
torch
.
sum
(
predicted_tm_term
*
normed_residue_mask
,
dim
=-
1
)
weighted
=
per_alignment
*
residue_weights
idx
=
weighted
.
argmax
(
dim
=-
1
,
keepdim
=
True
)
return
torch
.
gather
(
per_alignment
,
-
1
,
idx
).
squeeze
(
-
1
)
argmax
=
(
weighted
==
torch
.
max
(
weighted
)).
nonzero
()[
0
]
return
per_alignment
[
tuple
(
argmax
)]
def
tm_loss
(
logits
,
...
...
@@ -1483,17 +1490,17 @@ def experimentally_resolved_loss(
loss
=
torch
.
sum
(
errors
*
atom37_atom_exists
,
dim
=-
1
)
loss
=
loss
/
(
eps
+
torch
.
sum
(
atom37_atom_exists
,
dim
=
(
-
1
,
-
2
)))
loss
=
torch
.
sum
(
loss
,
dim
=-
1
)
loss
=
loss
*
(
(
resolution
>=
min_resolution
)
&
(
resolution
<=
max_resolution
)
)
loss
=
torch
.
mean
(
loss
)
return
loss
def
masked_msa_loss
(
logits
,
true_msa
,
bert_mask
,
eps
=
1e-8
,
**
kwargs
):
def
masked_msa_loss
(
logits
,
true_msa
,
bert_mask
,
num_classes
,
eps
=
1e-8
,
**
kwargs
):
"""
Computes BERT-style masked MSA loss. Implements subsection 1.9.9.
...
...
@@ -1505,7 +1512,7 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
Masked MSA loss
"""
errors
=
softmax_cross_entropy
(
logits
,
torch
.
nn
.
functional
.
one_hot
(
true_msa
,
num_classes
=
23
)
logits
,
torch
.
nn
.
functional
.
one_hot
(
true_msa
,
num_classes
=
num_classes
)
)
# FP16-friendly averaging. Equivalent to:
...
...
@@ -1562,10 +1569,10 @@ class AlphaFoldLoss(nn.Module):
batch
,
self
.
config
.
fape
,
),
"lddt"
:
lambda
:
lddt_loss
(
"
p
lddt
_loss
"
:
lambda
:
lddt_loss
(
logits
=
out
[
"lddt_logits"
],
all_atom_pred_pos
=
out
[
"final_atom_positions"
],
**
{
**
batch
,
**
self
.
config
.
lddt
},
**
{
**
batch
,
**
self
.
config
.
p
lddt
_loss
},
),
"masked_msa"
:
lambda
:
masked_msa_loss
(
logits
=
out
[
"masked_msa_logits"
],
...
...
openfold/utils/precision_utils.py
0 → 100644
View file @
39a6d0e6
# Copyright 2022 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
importlib
import
torch
def
is_fp16_enabled
():
# Autocast world
fp16_enabled
=
torch
.
get_autocast_gpu_dtype
()
==
torch
.
float16
fp16_enabled
=
fp16_enabled
and
torch
.
is_autocast_enabled
()
return
fp16_enabled
openfold/utils/rigid_utils.py
View file @
39a6d0e6
...
...
@@ -14,6 +14,7 @@
# limitations under the License.
from
__future__
import
annotations
from
functools
import
lru_cache
from
typing
import
Tuple
,
Any
,
Sequence
,
Callable
,
Optional
import
numpy
as
np
...
...
@@ -34,51 +35,31 @@ def rot_matmul(
Returns:
The product ab
"""
row_1
=
torch
.
stack
(
[
a
[...,
0
,
0
]
*
b
[...,
0
,
0
]
+
a
[...,
0
,
1
]
*
b
[...,
1
,
0
]
+
a
[...,
0
,
2
]
*
b
[...,
2
,
0
],
a
[...,
0
,
0
]
*
b
[...,
0
,
1
]
+
a
[...,
0
,
1
]
*
b
[...,
1
,
1
]
+
a
[...,
0
,
2
]
*
b
[...,
2
,
1
],
a
[...,
0
,
0
]
*
b
[...,
0
,
2
]
+
a
[...,
0
,
1
]
*
b
[...,
1
,
2
]
+
a
[...,
0
,
2
]
*
b
[...,
2
,
2
],
],
dim
=-
1
,
)
row_2
=
torch
.
stack
(
[
a
[...,
1
,
0
]
*
b
[...,
0
,
0
]
+
a
[...,
1
,
1
]
*
b
[...,
1
,
0
]
+
a
[...,
1
,
2
]
*
b
[...,
2
,
0
],
a
[...,
1
,
0
]
*
b
[...,
0
,
1
]
+
a
[...,
1
,
1
]
*
b
[...,
1
,
1
]
+
a
[...,
1
,
2
]
*
b
[...,
2
,
1
],
a
[...,
1
,
0
]
*
b
[...,
0
,
2
]
+
a
[...,
1
,
1
]
*
b
[...,
1
,
2
]
+
a
[...,
1
,
2
]
*
b
[...,
2
,
2
],
],
dim
=-
1
,
)
row_3
=
torch
.
stack
(
def
row_mul
(
i
):
return
torch
.
stack
(
[
a
[...,
i
,
0
]
*
b
[...,
0
,
0
]
+
a
[...,
i
,
1
]
*
b
[...,
1
,
0
]
+
a
[...,
i
,
2
]
*
b
[...,
2
,
0
],
a
[...,
i
,
0
]
*
b
[...,
0
,
1
]
+
a
[...,
i
,
1
]
*
b
[...,
1
,
1
]
+
a
[...,
i
,
2
]
*
b
[...,
2
,
1
],
a
[...,
i
,
0
]
*
b
[...,
0
,
2
]
+
a
[...,
i
,
1
]
*
b
[...,
1
,
2
]
+
a
[...,
i
,
2
]
*
b
[...,
2
,
2
],
],
dim
=-
1
,
)
return
torch
.
stack
(
[
a
[...,
2
,
0
]
*
b
[...,
0
,
0
]
+
a
[...,
2
,
1
]
*
b
[...,
1
,
0
]
+
a
[...,
2
,
2
]
*
b
[...,
2
,
0
],
a
[...,
2
,
0
]
*
b
[...,
0
,
1
]
+
a
[...,
2
,
1
]
*
b
[...,
1
,
1
]
+
a
[...,
2
,
2
]
*
b
[...,
2
,
1
],
a
[...,
2
,
0
]
*
b
[...,
0
,
2
]
+
a
[...,
2
,
1
]
*
b
[...,
1
,
2
]
+
a
[...,
2
,
2
]
*
b
[...,
2
,
2
],
],
dim
=-
1
,
row_mul
(
0
),
row_mul
(
1
),
row_mul
(
2
),
],
dim
=-
2
)
return
torch
.
stack
([
row_1
,
row_2
,
row_3
],
dim
=-
2
)
def
rot_vec_mul
(
r
:
torch
.
Tensor
,
...
...
@@ -94,9 +75,7 @@ def rot_vec_mul(
Returns:
[*, 3] rotated coordinates
"""
x
=
t
[...,
0
]
y
=
t
[...,
1
]
z
=
t
[...,
2
]
x
,
y
,
z
=
torch
.
unbind
(
t
,
dim
=-
1
)
return
torch
.
stack
(
[
r
[...,
0
,
0
]
*
x
+
r
[...,
0
,
1
]
*
y
+
r
[...,
0
,
2
]
*
z
,
...
...
@@ -106,7 +85,7 @@ def rot_vec_mul(
dim
=-
1
,
)
@
lru_cache
(
maxsize
=
None
)
def
identity_rot_mats
(
batch_dims
:
Tuple
[
int
],
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
...
...
@@ -118,10 +97,12 @@ def identity_rot_mats(
)
rots
=
rots
.
view
(
*
((
1
,)
*
len
(
batch_dims
)),
3
,
3
)
rots
=
rots
.
expand
(
*
batch_dims
,
-
1
,
-
1
)
rots
=
rots
.
contiguous
()
return
rots
@
lru_cache
(
maxsize
=
None
)
def
identity_trans
(
batch_dims
:
Tuple
[
int
],
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
...
...
@@ -137,6 +118,7 @@ def identity_trans(
return
trans
@
lru_cache
(
maxsize
=
None
)
def
identity_quats
(
batch_dims
:
Tuple
[
int
],
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
...
...
@@ -196,7 +178,7 @@ def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
quat
=
quat
[...,
None
]
*
quat
[...,
None
,
:]
# [4, 4, 3, 3]
mat
=
quat
.
new_tensor
(
_QTR_MAT
,
requires_grad
=
Fals
e
)
mat
=
_get_quat
(
"
_QTR_MAT
"
,
dtype
=
quat
.
dtype
,
device
=
quat
.
devic
e
)
# [*, 4, 4, 3, 3]
shaped_qtr_mat
=
mat
.
view
((
1
,)
*
len
(
quat
.
shape
[:
-
2
])
+
mat
.
shape
)
...
...
@@ -251,10 +233,20 @@ _QUAT_MULTIPLY[:, :, 3] = [[ 0, 0, 0, 1],
_QUAT_MULTIPLY_BY_VEC
=
_QUAT_MULTIPLY
[:,
1
:,
:]
_CACHED_QUATS
=
{
"_QTR_MAT"
:
_QTR_MAT
,
"_QUAT_MULTIPLY"
:
_QUAT_MULTIPLY
,
"_QUAT_MULTIPLY_BY_VEC"
:
_QUAT_MULTIPLY_BY_VEC
}
@lru_cache(maxsize=None)
def _get_quat(quat_key, dtype, device):
    """Return the cached quaternion helper matrix named ``quat_key`` as a tensor.

    Memoized per (key, dtype, device) so each constant is materialized only
    once per configuration.
    """
    source = _CACHED_QUATS[quat_key]
    return torch.tensor(source, dtype=dtype, device=device)
def
quat_multiply
(
quat1
,
quat2
):
"""Multiply a quaternion by another quaternion."""
mat
=
quat1
.
new_tensor
(
_QUAT_MULTIPLY
)
mat
=
_get_quat
(
"_QUAT_MULTIPLY"
,
dtype
=
quat1
.
dtype
,
device
=
quat1
.
device
)
reshaped_mat
=
mat
.
view
((
1
,)
*
len
(
quat1
.
shape
[:
-
1
])
+
mat
.
shape
)
return
torch
.
sum
(
reshaped_mat
*
...
...
@@ -266,7 +258,7 @@ def quat_multiply(quat1, quat2):
def
quat_multiply_by_vec
(
quat
,
vec
):
"""Multiply a quaternion by a pure-vector quaternion."""
mat
=
quat
.
new_tensor
(
_QUAT_MULTIPLY_BY_VEC
)
mat
=
_get_quat
(
"
_QUAT_MULTIPLY_BY_VEC
"
,
dtype
=
quat
.
dtype
,
device
=
quat
.
device
)
reshaped_mat
=
mat
.
view
((
1
,)
*
len
(
quat
.
shape
[:
-
1
])
+
mat
.
shape
)
return
torch
.
sum
(
reshaped_mat
*
...
...
openfold/utils/script_utils.py
0 → 100644
View file @
39a6d0e6
import
json
import
logging
import
os
import
re
import
time
import
numpy
import
torch
from
openfold.model.model
import
AlphaFold
from
openfold.np
import
residue_constants
,
protein
from
openfold.np.relax
import
relax
from
openfold.utils.import_weights
import
(
import_jax_weights_
,
)
from
pytorch_lightning.utilities.deepspeed
import
(
convert_zero_checkpoint_to_fp32_state_dict
)
logging
.
basicConfig
()
logger
=
logging
.
getLogger
(
__file__
)
logger
.
setLevel
(
level
=
logging
.
INFO
)
def count_models_to_evaluate(openfold_checkpoint_path, jax_param_path):
    """Count the total number of models named across both comma-separated
    checkpoint path strings (either may be None or empty)."""
    candidate_lists = [openfold_checkpoint_path, jax_param_path]
    return sum(len(paths.split(",")) for paths in candidate_lists if paths)
def get_model_basename(model_path):
    """Return the final path component of ``model_path`` without its extension.

    Trailing slashes are ignored, so checkpoint directories work too.
    """
    normalized = os.path.normpath(model_path)
    filename = os.path.basename(normalized)
    stem, _ext = os.path.splitext(filename)
    return stem
def make_output_directory(output_dir, model_name, multiple_model_mode):
    """Create (if necessary) and return the prediction directory for a model.

    In multiple-model mode each model gets its own subdirectory under
    "predictions"; otherwise a single shared "predictions" directory is used.
    """
    parts = ["predictions", model_name] if multiple_model_mode else ["predictions"]
    prediction_dir = os.path.join(output_dir, *parts)
    os.makedirs(prediction_dir, exist_ok=True)
    return prediction_dir
def load_models_from_command_line(
    config, model_device, openfold_checkpoint_path, jax_param_path, output_dir
):
    """Yield (model, prediction_output_directory) for every requested model.

    Args:
        config: Model config passed to the AlphaFold constructor.
        model_device: Device string (e.g. "cpu", "cuda:0") the model is moved to.
        openfold_checkpoint_path: Comma-separated OpenFold checkpoints; each
            entry is either a .pt file or a DeepSpeed checkpoint directory.
            May be None/empty.
        jax_param_path: Comma-separated JAX/AlphaFold parameter files.
            May be None/empty.
        output_dir: Root directory under which per-model prediction
            directories are created.

    Raises:
        ValueError: If neither checkpoint source is specified.
    """
    # Fail fast (on first iteration) if there is nothing to load. The
    # original raised after the loops; with both paths empty the loops were
    # no-ops, so raising up front is equivalent but clearer.
    if not jax_param_path and not openfold_checkpoint_path:
        raise ValueError(
            "At least one of jax_param_path or openfold_checkpoint_path must "
            "be specified."
        )

    multiple_model_mode = (
        count_models_to_evaluate(openfold_checkpoint_path, jax_param_path) > 1
    )
    if multiple_model_mode:
        # Fixed: was an f-string with no placeholders
        logger.info("evaluating multiple models")

    if jax_param_path:
        for path in jax_param_path.split(","):
            model_basename = get_model_basename(path)
            # e.g. "params_model_1_ptm" -> version "model_1_ptm"
            model_version = "_".join(model_basename.split("_")[1:])
            model = AlphaFold(config)
            model = model.eval()
            import_jax_weights_(model, path, version=model_version)
            model = model.to(model_device)
            logger.info(f"Successfully loaded JAX parameters at {path}...")
            output_directory = make_output_directory(
                output_dir, model_basename, multiple_model_mode
            )
            yield model, output_directory

    if openfold_checkpoint_path:
        for path in openfold_checkpoint_path.split(","):
            model = AlphaFold(config)
            model = model.eval()
            checkpoint_basename = get_model_basename(path)
            if os.path.isdir(path):
                # A DeepSpeed checkpoint: convert to a single-file fp32 state
                # dict (once) before loading.
                ckpt_path = os.path.join(
                    output_dir,
                    checkpoint_basename + ".pt",
                )
                if not os.path.isfile(ckpt_path):
                    convert_zero_checkpoint_to_fp32_state_dict(
                        path,
                        ckpt_path,
                    )
                # map_location prevents failures when loading CUDA-saved
                # tensors on a CPU-only host; the model is moved to
                # model_device below regardless.
                d = torch.load(ckpt_path, map_location="cpu")
                model.load_state_dict(d["ema"]["params"])
            else:
                ckpt_path = path
                d = torch.load(ckpt_path, map_location="cpu")
                if "ema" in d:
                    # The public weights have had this done to them already
                    d = d["ema"]["params"]
                model.load_state_dict(d)

            model = model.to(model_device)
            logger.info(f"Loaded OpenFold parameters at {path}...")
            output_directory = make_output_directory(
                output_dir, checkpoint_basename, multiple_model_mode
            )
            yield model, output_directory
def parse_fasta(data):
    """Parse FASTA-formatted text.

    Returns:
        (tags, seqs): parallel lists where tags[i] is the first whitespace-
        delimited token of record i's header and seqs[i] is its sequence
        with internal newlines removed.
    """
    # Drop a stray '>' appearing at the end of any line
    cleaned = re.sub('>$', '', data, flags=re.M)

    entries = []
    for record in cleaned.split('>'):
        # Split each record into (header, body); body keeps its newlines,
        # which are then stripped so the sequence is one string.
        pieces = record.strip().split('\n', 1)
        entries.extend(piece.replace('\n', '') for piece in pieces)
    # The text before the first '>' is an empty pseudo-record; discard it
    entries = entries[1:]

    headers, seqs = entries[::2], entries[1::2]
    tags = [header.split()[0] for header in headers]
    return tags, seqs
def update_timings(timing_dict, output_file=os.path.join(os.getcwd(), "timings.json")):
    """Merge one or more run-step timings into a JSON file.

    Existing entries in the file are preserved unless overwritten by
    ``timing_dict``; unparseable files are replaced wholesale.

    Returns:
        The path of the file written.
    """
    timings = {}
    if os.path.exists(output_file):
        with open(output_file, "r") as f:
            try:
                timings = json.load(f)
            except json.JSONDecodeError:
                logger.info(f"Overwriting non-standard JSON in {output_file}.")
                timings = {}
    timings.update(timing_dict)
    with open(output_file, "w") as f:
        json.dump(timings, f)
    return output_file
def run_model(model, batch, tag, output_dir):
    """Run inference on one batch under no_grad, logging and recording timing.

    Template processing is temporarily disabled when the batch carries no
    "template_*" features, and the original setting is restored afterwards.
    """
    with torch.no_grad():
        # Temporarily disable templates if there aren't any in the batch
        template_enabled = model.config.template.enabled
        batch_has_templates = any("template_" in key for key in batch)
        model.config.template.enabled = template_enabled and batch_has_templates

        logger.info(f"Running inference for {tag}...")
        start = time.perf_counter()
        out = model(batch)
        inference_time = time.perf_counter() - start
        logger.info(f"Inference time: {inference_time}")
        update_timings(
            {"inference": inference_time},
            os.path.join(output_dir, "timings.json"),
        )

        model.config.template.enabled = template_enabled
    return out
def prep_output(
    out,
    batch,
    feature_dict,
    feature_processor,
    config_preset,
    multimer_ri_gap,
    subtract_plddt,
):
    """Assemble an unrelaxed Protein object from raw model output.

    Broadcasts per-residue pLDDT to per-atom B-factors, gathers template
    metadata for the PDB remark, undoes the large residue-index gaps used
    to separate chains in multi-chain FASTAs (mutates
    batch["residue_index"] in place), and calls protein.from_prediction.
    """
    plddt = out["plddt"]
    # One confidence value per residue, repeated across all atom types
    plddt_b_factors = numpy.repeat(
        plddt[..., None], residue_constants.atom_type_num, axis=-1
    )
    if subtract_plddt:
        plddt_b_factors = 100 - plddt_b_factors

    # Prep protein metadata
    template_domain_names = []
    template_chain_index = None
    uses_templates = feature_processor.config.common.use_templates
    if uses_templates and "template_domain_names" in feature_dict:
        template_domain_names = [
            name.decode("utf-8")
            for name in feature_dict["template_domain_names"]
        ]
        # This works because templates are not shuffled during inference
        max_templates = feature_processor.config.predict.max_templates
        template_domain_names = template_domain_names[:max_templates]
        if "template_chain_index" in feature_dict:
            template_chain_index = feature_dict["template_chain_index"]
            template_chain_index = template_chain_index[:max_templates]

    no_recycling = feature_processor.config.common.max_recycling_iters
    remark = ', '.join([
        f"no_recycling={no_recycling}",
        f"max_templates={feature_processor.config.predict.max_templates}",
        f"config_preset={config_preset}",
    ])

    # For multi-chain FASTAs: recover each residue's chain id from the
    # artificial index gap, then subtract the gap back out per chain
    ri = feature_dict["residue_index"]
    chain_index = (ri - numpy.arange(ri.shape[0])) / multimer_ri_gap
    chain_index = chain_index.astype(numpy.int64)
    cur_chain = 0
    prev_chain_max = 0
    for i, c in enumerate(chain_index):
        if c != cur_chain:
            cur_chain = c
            prev_chain_max = i + cur_chain * multimer_ri_gap
        batch["residue_index"][i] -= prev_chain_max

    unrelaxed_protein = protein.from_prediction(
        features=batch,
        result=out,
        b_factors=plddt_b_factors,
        remove_leading_feature_dimension=not "multimer" in config_preset,
        remark=remark,
        parents=template_domain_names,
        parents_chain_index=template_chain_index,
    )

    return unrelaxed_protein
def relax_protein(
    config, model_device, unrelaxed_protein, output_directory, output_name
):
    """Amber-relax a predicted structure and write it to <output_name>_relaxed.pdb.

    Relaxation wall-clock time is recorded to timings.json in
    output_directory. While relaxing on a CUDA device, CUDA_VISIBLE_DEVICES
    is pinned to that device and restored afterwards.
    """
    amber_relaxer = relax.AmberRelaxation(
        use_gpu=(model_device != "cpu"),
        **config.relax,
    )

    start = time.perf_counter()
    visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default="")
    if "cuda" in model_device:
        # Restrict the relaxer to the same GPU the model ran on
        device_no = model_device.split(":")[-1]
        os.environ["CUDA_VISIBLE_DEVICES"] = device_no
    relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
    os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
    relaxation_time = time.perf_counter() - start

    logger.info(f"Relaxation time: {relaxation_time}")
    update_timings(
        {"relaxation": relaxation_time},
        os.path.join(output_directory, "timings.json"),
    )

    # Save the relaxed PDB.
    relaxed_output_path = os.path.join(
        output_directory, f'{output_name}_relaxed.pdb'
    )
    with open(relaxed_output_path, 'w') as fp:
        fp.write(relaxed_pdb_str)

    logger.info(f"Relaxed output written to {relaxed_output_path}...")
\ No newline at end of file
openfold/utils/superimposition.py
View file @
39a6d0e6
...
...
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
Bio.SVDSuperimposer
import
SVDSuperimposer
import
numpy
as
np
import
torch
...
...
openfold/utils/tensor_utils.py
View file @
39a6d0e6
...
...
@@ -14,9 +14,22 @@
# limitations under the License.
from
functools
import
partial
import
logging
from
typing
import
Tuple
,
List
,
Callable
,
Any
,
Dict
,
Sequence
,
Optional
import
torch
import
torch.nn
as
nn
from
typing
import
Tuple
,
List
,
Callable
,
Any
,
Dict
,
Sequence
,
Optional
def add(m1, m2, inplace):
    """Return m1 + m2, mutating m1 when ``inplace`` is True.

    The first operation in a gradient checkpoint can't be in-place, but
    in-place addition saves memory during inference — hence the switch.
    """
    if inplace:
        m1 += m2
        return m1
    return m1 + m2
def
permute_final_dims
(
tensor
:
torch
.
Tensor
,
inds
:
List
[
int
]):
...
...
@@ -106,303 +119,3 @@ def tree_map(fn, tree, leaf_type):
tensor_tree_map
=
partial
(
tree_map
,
leaf_type
=
torch
.
Tensor
)
def
_fetch_dims
(
tree
):
shapes
=
[]
tree_type
=
type
(
tree
)
if
tree_type
is
dict
:
for
v
in
tree
.
values
():
shapes
.
extend
(
_fetch_dims
(
v
))
elif
tree_type
is
list
or
tree_type
is
tuple
:
for
t
in
tree
:
shapes
.
extend
(
_fetch_dims
(
t
))
elif
tree_type
is
torch
.
Tensor
:
shapes
.
append
(
tree
.
shape
)
else
:
raise
ValueError
(
"Not supported"
)
return
shapes
@
torch
.
jit
.
ignore
def
_flat_idx_to_idx
(
flat_idx
:
int
,
dims
:
Tuple
[
int
],
)
->
Tuple
[
int
]:
idx
=
[]
for
d
in
reversed
(
dims
):
idx
.
append
(
flat_idx
%
d
)
flat_idx
=
flat_idx
//
d
return
tuple
(
reversed
(
idx
))
@
torch
.
jit
.
ignore
def
_get_minimal_slice_set
(
start
:
Sequence
[
int
],
end
:
Sequence
[
int
],
dims
:
int
,
start_edges
:
Optional
[
Sequence
[
bool
]]
=
None
,
end_edges
:
Optional
[
Sequence
[
bool
]]
=
None
,
)
->
Sequence
[
Tuple
[
int
]]:
"""
Produces an ordered sequence of tensor slices that, when used in
sequence on a tensor with shape dims, yields tensors that contain every
leaf in the contiguous range [start, end]. Care is taken to yield a
short sequence of slices, and perhaps even the shortest possible (I'm
pretty sure it's the latter).
end is INCLUSIVE.
"""
# start_edges and end_edges both indicate whether, starting from any given
# dimension, the start/end index is at the top/bottom edge of the
# corresponding tensor, modeled as a tree
def
reduce_edge_list
(
l
):
tally
=
1
for
i
in
range
(
len
(
l
)):
reversed_idx
=
-
1
*
(
i
+
1
)
l
[
reversed_idx
]
*=
tally
tally
=
l
[
reversed_idx
]
if
(
start_edges
is
None
):
start_edges
=
[
s
==
0
for
s
in
start
]
reduce_edge_list
(
start_edges
)
if
(
end_edges
is
None
):
end_edges
=
[
e
==
(
d
-
1
)
for
e
,
d
in
zip
(
end
,
dims
)]
reduce_edge_list
(
end_edges
)
# Base cases. Either start/end are empty and we're done, or the final,
# one-dimensional tensor can be simply sliced
if
(
len
(
start
)
==
0
):
return
[
tuple
()]
elif
(
len
(
start
)
==
1
):
return
[(
slice
(
start
[
0
],
end
[
0
]
+
1
),)]
slices
=
[]
path
=
[]
# Dimensions common to start and end can be selected directly
for
s
,
e
in
zip
(
start
,
end
):
if
(
s
==
e
):
path
.
append
(
slice
(
s
,
s
+
1
))
else
:
break
path
=
tuple
(
path
)
divergence_idx
=
len
(
path
)
# start == end, and we're done
if
(
divergence_idx
==
len
(
dims
)):
return
[
tuple
(
path
)]
def
upper
():
sdi
=
start
[
divergence_idx
]
return
[
path
+
(
slice
(
sdi
,
sdi
+
1
),)
+
s
for
s
in
_get_minimal_slice_set
(
start
[
divergence_idx
+
1
:],
[
d
-
1
for
d
in
dims
[
divergence_idx
+
1
:]],
dims
[
divergence_idx
+
1
:],
start_edges
=
start_edges
[
divergence_idx
+
1
:],
end_edges
=
[
1
for
_
in
end_edges
[
divergence_idx
+
1
:]]
)
]
def
lower
():
edi
=
end
[
divergence_idx
]
return
[
path
+
(
slice
(
edi
,
edi
+
1
),)
+
s
for
s
in
_get_minimal_slice_set
(
[
0
for
_
in
start
[
divergence_idx
+
1
:]],
end
[
divergence_idx
+
1
:],
dims
[
divergence_idx
+
1
:],
start_edges
=
[
1
for
_
in
start_edges
[
divergence_idx
+
1
:]],
end_edges
=
end_edges
[
divergence_idx
+
1
:],
)
]
# If both start and end are at the edges of the subtree rooted at
# divergence_idx, we can just select the whole subtree at once
if
(
start_edges
[
divergence_idx
]
and
end_edges
[
divergence_idx
]):
slices
.
append
(
path
+
(
slice
(
start
[
divergence_idx
],
end
[
divergence_idx
]
+
1
),)
)
# If just start is at the edge, we can grab almost all of the subtree,
# treating only the ragged bottom edge as an edge case
elif
(
start_edges
[
divergence_idx
]):
slices
.
append
(
path
+
(
slice
(
start
[
divergence_idx
],
end
[
divergence_idx
]),)
)
slices
.
extend
(
lower
())
# Analogous to the previous case, but the top is ragged this time
elif
(
end_edges
[
divergence_idx
]):
slices
.
extend
(
upper
())
slices
.
append
(
path
+
(
slice
(
start
[
divergence_idx
]
+
1
,
end
[
divergence_idx
]
+
1
),)
)
# If both sides of the range are ragged, we need to handle both sides
# separately. If there's contiguous meat in between them, we can index it
# in one big chunk
else
:
slices
.
extend
(
upper
())
middle_ground
=
end
[
divergence_idx
]
-
start
[
divergence_idx
]
if
(
middle_ground
>
1
):
slices
.
append
(
path
+
(
slice
(
start
[
divergence_idx
]
+
1
,
end
[
divergence_idx
]),)
)
slices
.
extend
(
lower
())
return
[
tuple
(
s
)
for
s
in
slices
]
@torch.jit.ignore
def _chunk_slice(
    t: torch.Tensor,
    flat_start: int,
    flat_end: int,
    no_batch_dims: int,
) -> torch.Tensor:
    """
    Equivalent to

        t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]

    but without the need for the initial reshape call, which can be
    memory-intensive in certain situations. The only reshape operations
    in this function are performed on sub-tensors that scale with
    (flat_end - flat_start), the chunk size.
    """
    batch_dims = t.shape[:no_batch_dims]
    start_idx = list(_flat_idx_to_idx(flat_start, batch_dims))
    # _get_minimal_slice_set is inclusive
    end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims))

    # Get an ordered list of slices to perform
    slices = _get_minimal_slice_set(
        start_idx,
        end_idx,
        batch_dims,
    )

    flat_trailing_shape = (-1,) + t.shape[no_batch_dims:]
    pieces = [t[s].view(flat_trailing_shape) for s in slices]
    return torch.cat(pieces)
def chunk_layer(
    layer: Callable,
    inputs: Dict[str, Any],
    chunk_size: int,
    no_batch_dims: int,
    low_mem: bool = False,
) -> Any:
    """
    Implements the "chunking" procedure described in section 1.11.8.

    Layer outputs and inputs are assumed to be simple "pytrees,"
    consisting only of (arbitrarily nested) lists, tuples, and dicts with
    torch.Tensor leaves.

    Args:
        layer:
            The layer to be applied chunk-wise
        inputs:
            A (non-nested) dictionary of keyworded inputs. All leaves must
            be tensors and must share the same batch dimensions.
        chunk_size:
            The number of sub-batches per chunk. If multiple batch
            dimensions are specified, a "sub-batch" is defined as a single
            indexing of all batch dimensions simultaneously (s.t. the
            number of sub-batches is the product of the batch dimensions).
        no_batch_dims:
            How many of the initial dimensions of each input tensor can
            be considered batch dimensions.
        low_mem:
            Avoids flattening potentially large input tensors. Unnecessary
            in most cases, and is ever so slightly slower than the default
            setting.
    Returns:
        The reassembled output of the layer on the inputs.
    """
    if not len(inputs) > 0:
        raise ValueError("Must provide at least one input")

    # Broadcastable batch shape shared by all inputs
    initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
    orig_batch_dims = tuple(max(s) for s in zip(*initial_dims))

    def _prep_inputs(t):
        # TODO: make this more memory efficient. This sucks
        if not low_mem:
            # Flatten the batch dims (expanding first unless every batch
            # dim is already singleton)
            if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
                t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
            t = t.reshape(-1, *t.shape[no_batch_dims:])
        else:
            t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
        return t

    prepped_inputs = tensor_tree_map(_prep_inputs, inputs)

    flat_batch_dim = 1
    for d in orig_batch_dims:
        flat_batch_dim *= d

    # Ceiling division
    no_chunks = flat_batch_dim // chunk_size + (flat_batch_dim % chunk_size != 0)

    cursor = 0
    out = None
    for _ in range(no_chunks):
        # Chunk the input
        if not low_mem:
            select_chunk = (
                lambda t: t[cursor : cursor + chunk_size]
                if t.shape[0] != 1
                else t
            )
        else:
            select_chunk = partial(
                _chunk_slice,
                flat_start=cursor,
                flat_end=min(flat_batch_dim, cursor + chunk_size),
                no_batch_dims=len(orig_batch_dims),
            )

        chunks = tensor_tree_map(select_chunk, prepped_inputs)

        # Run the layer on the chunk
        output_chunk = layer(**chunks)

        # Allocate space for the output
        if out is None:
            allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:])
            out = tensor_tree_map(allocate, output_chunk)

        # Put the chunk in its pre-allocated space
        out_type = type(output_chunk)
        if out_type is dict:
            def assign(d1, d2):
                for k, v in d1.items():
                    if type(v) is dict:
                        assign(v, d2[k])
                    else:
                        v[cursor : cursor + chunk_size] = d2[k]

            assign(out, output_chunk)
        elif out_type is tuple:
            for x1, x2 in zip(out, output_chunk):
                x1[cursor : cursor + chunk_size] = x2
        elif out_type is torch.Tensor:
            out[cursor : cursor + chunk_size] = output_chunk
        else:
            raise ValueError("Not supported")

        cursor += chunk_size

    # Restore the original (unflattened) batch dimensions
    unflatten = lambda t: t.view(orig_batch_dims + t.shape[1:])
    out = tensor_tree_map(unflatten, out)

    return out
openfold/utils/trace_utils.py
0 → 100644
View file @
39a6d0e6
# Copyright 2022 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
contextlib
from
functools
import
partialmethod
import
numpy
as
np
import
torch
from
openfold.utils.tensor_utils
import
tensor_tree_map
def pad_feature_dict_seq(feature_dict, seqlen):
    """Zero-pad the sequence dimension of a feature dict to ``seqlen``.

    Used for tracing, where fixed input shapes are required. Features with
    no sequence axis are passed through unchanged; "seq_length" is updated
    to the padded length.
    """
    # The real sequence length can't be longer than the desired one
    true_n = feature_dict["aatype"].shape[-2]
    assert true_n <= seqlen

    # Which (negative) axis of each feature indexes the sequence
    feat_seq_dims = {
        "aatype": -2,
        "between_segment_residues": -1,
        "residue_index": -1,
        "seq_length": -1,
        "deletion_matrix_int": -1,
        "msa": -1,
        "num_alignments": -1,
        "template_aatype": -2,
        "template_all_atom_mask": -2,
        "template_all_atom_positions": -3,
    }

    new_feature_dict = {}
    for key, value in feature_dict.items():
        if key not in feat_seq_dims:
            new_feature_dict[key] = value
            continue
        seq_dim = feat_seq_dims[key]
        padded_shape = list(value.shape)
        padded_shape[seq_dim] = seqlen
        padded = np.zeros(padded_shape, dtype=value.dtype)
        # Copy the original values into the leading corner of the new array
        padded[tuple(slice(0, s) for s in value.shape)] = value
        new_feature_dict[key] = padded

    new_feature_dict["seq_length"][0] = seqlen
    return new_feature_dict
def
trace_model_
(
model
,
sample_input
):
# Grab the inputs to the final recycling iteration
feats
=
tensor_tree_map
(
lambda
t
:
t
[...,
-
1
],
sample_input
)
# Gather some metadata
n
=
feats
[
"aatype"
].
shape
[
-
1
]
msa_depth
=
feats
[
"true_msa"
].
shape
[
-
2
]
extra_msa_depth
=
feats
[
"extra_msa"
].
shape
[
-
2
]
no_templates
=
feats
[
"template_aatype"
].
shape
[
-
2
]
device
=
feats
[
"aatype"
].
device
seq_mask
=
feats
[
"seq_mask"
].
to
(
device
)
pair_mask
=
seq_mask
[...,
None
]
*
seq_mask
[...,
None
,
:]
extra_msa_mask
=
feats
[
"extra_msa_mask"
].
to
(
device
)
template_pair_mask
=
torch
.
stack
([
pair_mask
]
*
no_templates
,
dim
=-
3
)
# Create some fake representations with the correct shapes
m
=
torch
.
rand
(
msa_depth
+
4
,
n
,
model
.
globals
.
c_m
).
to
(
device
)
z
=
torch
.
rand
(
n
,
n
,
model
.
globals
.
c_z
).
to
(
device
)
t
=
torch
.
rand
(
no_templates
,
n
,
n
,
model
.
globals
.
c_t
).
to
(
device
)
a
=
torch
.
rand
(
extra_msa_depth
,
n
,
model
.
globals
.
c_e
).
to
(
device
)
msa_mask
=
torch
.
randint
(
0
,
1
,
(
msa_depth
+
4
,
n
)).
to
(
device
)
# We need to do a dry run through the model so the chunk size tuners'
# trial runs (which run during the first-ever model iteration) aren't
# baked into the trace. There's no need to run the entire thing,
# though; we just need to run one block from each transformer stack.
evoformer_blocks
=
model
.
evoformer
.
blocks
model
.
evoformer
.
blocks
=
evoformer_blocks
[:
1
]
extra_msa_blocks
=
model
.
extra_msa_stack
.
blocks
model
.
extra_msa_stack
.
blocks
=
extra_msa_blocks
[:
1
]
if
(
model
.
template_config
.
enabled
):
template_pair_stack_blocks
=
model
.
template_pair_stack
.
blocks
model
.
template_pair_stack
.
blocks
=
template_pair_stack_blocks
[:
1
]
single_recycling_iter_input
=
tensor_tree_map
(
lambda
t
:
t
[...,
:
1
],
sample_input
,
)
with
torch
.
no_grad
():
_
=
model
(
single_recycling_iter_input
)
model
.
evoformer
.
blocks
=
evoformer_blocks
model
.
extra_msa_stack
.
blocks
=
extra_msa_blocks
del
evoformer_blocks
,
extra_msa_blocks
if
(
model
.
template_config
.
enabled
):
model
.
template_pair_stack
.
blocks
=
template_pair_stack_blocks
del
template_pair_stack_blocks
def
get_tuned_chunk_size
(
module
):
tuner
=
module
.
chunk_size_tuner
chunk_size
=
tuner
.
cached_chunk_size
# After our trial run above, this should always be set
assert
(
chunk_size
is
not
None
)
return
chunk_size
# Fetch the resulting chunk sizes
evoformer_chunk_size
=
model
.
globals
.
chunk_size
if
(
model
.
evoformer
.
chunk_size_tuner
is
not
None
):
evoformer_chunk_size
=
get_tuned_chunk_size
(
model
.
evoformer
)
extra_msa_chunk_size
=
model
.
globals
.
chunk_size
if
(
model
.
extra_msa_stack
.
chunk_size_tuner
is
not
None
):
extra_msa_chunk_size
=
get_tuned_chunk_size
(
model
.
extra_msa_stack
)
if
(
model
.
template_config
.
enabled
):
template_pair_stack_chunk_size
=
model
.
globals
.
chunk_size
if
(
model
.
template_pair_stack
.
chunk_size_tuner
is
not
None
):
template_pair_stack_chunk_size
=
get_tuned_chunk_size
(
model
.
template_pair_stack
)
def
trace_block
(
block
,
block_inputs
):
# Yes, yes, I know
with
contextlib
.
redirect_stderr
(
None
):
traced_block
=
torch
.
jit
.
trace
(
block
,
block_inputs
)
traced_block
=
torch
.
jit
.
freeze
(
traced_block
,
optimize_numerics
=
True
)
# It would be nice to use this, but its runtimes are extremely
# unpredictable
# traced_block = torch.jit.optimize_for_inference(traced_block)
# All trace inputs need to be tensors. This wrapper takes care of that
def
traced_block_wrapper
(
*
args
,
**
kwargs
):
to_tensor
=
lambda
t
:
torch
.
tensor
(
t
)
if
type
(
t
)
!=
torch
.
Tensor
else
t
args
=
[
to_tensor
(
a
)
for
a
in
args
]
kwargs
=
{
k
:
to_tensor
(
v
)
for
k
,
v
in
kwargs
.
items
()}
return
traced_block
(
*
args
,
**
kwargs
)
return
traced_block_wrapper
def
verify_arg_order
(
fn
,
arg_list
):
""" Because it's difficult to specify keyword arguments of Module
functions during tracing, we need to pass them as a tuple. As a
sanity check, we manually verify their order here.
"""
fn_arg_names
=
fn
.
__code__
.
co_varnames
# Remove the "self" parameter
assert
(
fn_arg_names
[
0
]
==
"self"
)
fn_arg_names
=
fn_arg_names
[
1
:]
# Trim unspecified arguments
fn_arg_names
=
fn_arg_names
[:
len
(
arg_list
)]
name_tups
=
list
(
zip
(
fn_arg_names
,
[
n
for
n
,
_
in
arg_list
]))
assert
(
all
([
n1
==
n2
for
n1
,
n2
in
name_tups
]))
evoformer_attn_chunk_size
=
max
(
model
.
globals
.
chunk_size
,
evoformer_chunk_size
//
4
)
# MSA row attention
msa_att_row_arg_tuples
=
[
(
"m"
,
m
),
(
"z"
,
z
),
(
"mask"
,
msa_mask
),
(
"chunk_size"
,
torch
.
tensor
(
evoformer_attn_chunk_size
)),
(
"use_memory_efficient_kernel"
,
torch
.
tensor
(
False
)),
(
"use_lma"
,
torch
.
tensor
(
model
.
globals
.
use_lma
)),
]
verify_arg_order
(
model
.
evoformer
.
blocks
[
0
].
msa_att_row
.
forward
,
msa_att_row_arg_tuples
)
msa_att_row_args
=
[
arg
for
_
,
arg
in
msa_att_row_arg_tuples
]
with
torch
.
no_grad
():
for
b
in
model
.
evoformer
.
blocks
:
traced_block
=
trace_block
(
b
.
msa_att_row
,
msa_att_row_args
)
del
b
.
msa_att_row
b
.
msa_att_row
=
traced_block
# MSA col attention
msa_att_col_arg_tuples
=
[
(
"m"
,
m
),
(
"mask"
,
msa_mask
),
(
"chunk_size"
,
torch
.
tensor
(
evoformer_chunk_size
)),
(
"use_lma"
,
torch
.
tensor
(
model
.
globals
.
use_lma
)),
(
"use_flash"
,
torch
.
tensor
(
model
.
globals
.
use_flash
)),
]
verify_arg_order
(
model
.
evoformer
.
blocks
[
0
].
msa_att_col
.
forward
,
msa_att_col_arg_tuples
)
msa_att_col_args
=
[
arg
for
_
,
arg
in
msa_att_col_arg_tuples
]
with
torch
.
no_grad
():
for
b
in
model
.
evoformer
.
blocks
:
traced_block
=
trace_block
(
b
.
msa_att_col
,
msa_att_col_args
)
del
b
.
msa_att_col
b
.
msa_att_col
=
traced_block
# OPM
opm_arg_tuples
=
[
(
"m"
,
m
),
(
"mask"
,
msa_mask
.
float
()),
(
"chunk_size"
,
torch
.
tensor
(
evoformer_chunk_size
)),
(
"inplace_safe"
,
torch
.
tensor
(
True
)),
]
verify_arg_order
(
model
.
evoformer
.
blocks
[
0
].
core
.
outer_product_mean
.
forward
,
opm_arg_tuples
)
opm_args
=
[
arg
for
_
,
arg
in
opm_arg_tuples
]
with
torch
.
no_grad
():
for
b
in
model
.
evoformer
.
blocks
:
traced_block
=
trace_block
(
b
.
core
.
outer_product_mean
,
opm_args
)
del
b
.
core
.
outer_product_mean
b
.
core
.
outer_product_mean
=
traced_block
# Triangular multiplicative update (out)
tri_mul_out_arg_tuples
=
[
(
"z"
,
z
),
(
"mask"
,
pair_mask
.
float
()),
(
"inplace_safe"
,
torch
.
tensor
(
True
)),
(
"_add_with_inplace"
,
torch
.
tensor
(
True
)),
]
verify_arg_order
(
model
.
evoformer
.
blocks
[
0
].
core
.
tri_mul_out
.
forward
,
tri_mul_out_arg_tuples
)
tri_mul_out_args
=
[
arg
for
_
,
arg
in
tri_mul_out_arg_tuples
]
with
torch
.
no_grad
():
for
b
in
model
.
evoformer
.
blocks
:
traced_block
=
trace_block
(
b
.
core
.
tri_mul_out
,
tri_mul_out_args
)
del
b
.
core
.
tri_mul_out
b
.
core
.
tri_mul_out
=
traced_block
# Triangular multiplicative update (in)
tri_mul_in_arg_tuples
=
[
(
"z"
,
z
),
(
"mask"
,
pair_mask
.
float
()),
(
"inplace_safe"
,
torch
.
tensor
(
True
)),
(
"_add_with_inplace"
,
torch
.
tensor
(
True
)),
]
verify_arg_order
(
model
.
evoformer
.
blocks
[
0
].
core
.
tri_mul_in
.
forward
,
tri_mul_in_arg_tuples
)
tri_mul_in_args
=
[
arg
for
_
,
arg
in
tri_mul_in_arg_tuples
]
with
torch
.
no_grad
():
for
b
in
model
.
evoformer
.
blocks
:
traced_block
=
trace_block
(
b
.
core
.
tri_mul_in
,
tri_mul_in_args
)
del
b
.
core
.
tri_mul_in
b
.
core
.
tri_mul_in
=
traced_block
# Triangular attention (start)
tri_att_start_arg_tuples
=
[
(
"x"
,
z
),
(
"mask"
,
pair_mask
.
float
()),
(
"chunk_size"
,
torch
.
tensor
(
evoformer_attn_chunk_size
)),
(
"use_memory_efficient_kernel"
,
torch
.
tensor
(
False
)),
(
"use_lma"
,
torch
.
tensor
(
model
.
globals
.
use_lma
)),
(
"inplace_safe"
,
torch
.
tensor
(
True
)),
]
verify_arg_order
(
model
.
evoformer
.
blocks
[
0
].
core
.
tri_att_start
.
forward
,
tri_att_start_arg_tuples
)
tri_att_start_args
=
[
arg
for
_
,
arg
in
tri_att_start_arg_tuples
]
with
torch
.
no_grad
():
for
b
in
model
.
evoformer
.
blocks
:
traced_block
=
trace_block
(
b
.
core
.
tri_att_start
,
tri_att_start_args
)
del
b
.
core
.
tri_att_start
b
.
core
.
tri_att_start
=
traced_block
# Triangular attention (end)
tri_att_end_arg_tuples
=
[
(
"x"
,
z
.
transpose
(
-
2
,
-
3
)),
(
"mask"
,
pair_mask
.
transpose
(
-
1
,
-
2
).
float
()),
(
"chunk_size"
,
torch
.
tensor
(
evoformer_attn_chunk_size
)),
(
"use_memory_efficient_kernel"
,
torch
.
tensor
(
False
)),
(
"use_lma"
,
torch
.
tensor
(
model
.
globals
.
use_lma
)),
(
"inplace_safe"
,
torch
.
tensor
(
True
)),
]
verify_arg_order
(
model
.
evoformer
.
blocks
[
0
].
core
.
tri_att_end
.
forward
,
tri_att_end_arg_tuples
)
tri_att_end_args
=
[
arg
for
_
,
arg
in
tri_att_end_arg_tuples
]
with
torch
.
no_grad
():
for
b
in
model
.
evoformer
.
blocks
:
traced_block
=
trace_block
(
b
.
core
.
tri_att_end
,
tri_att_end_args
)
del
b
.
core
.
tri_att_end
b
.
core
.
tri_att_end
=
traced_block
#evoformer_arg_tuples = [
# ("m", m),
# ("z", z),
# ("msa_mask", msa_mask),
# ("pair_mask", pair_mask),
# ("chunk_size", torch.tensor(evoformer_chunk_size)),
# ("use_lma", torch.tensor(model.globals.use_lma)),
# ("use_flash", torch.tensor(model.globals.use_flash)),
# ("inplace_safe", torch.tensor(1)),
# ("_mask_trans", torch.tensor(model.config._mask_trans)),
# ("_attn_chunk_size", torch.tensor(evoformer_attn_chunk_size)),
#]
#verify_arg_order(model.evoformer.blocks[0].forward, evoformer_arg_tuples)
#evoformer_args = [arg for _, arg in evoformer_arg_tuples]
#with torch.no_grad():
# traced_evoformer_stack = []
# for b in model.evoformer.blocks:
# traced_block = trace_block(b, evoformer_args)
# traced_evoformer_stack.append(traced_block)
#del model.evoformer.blocks
#model.evoformer.blocks = traced_evoformer_stack
# with torch.no_grad():
# for b in model.evoformer.blocks:
# _ = b(*evoformer_args)
#
# with torch.no_grad():
# for b in model.evoformer.blocks:
# _ = b(*evoformer_args)
# extra_msa_attn_chunk_size = max(
# model.globals.chunk_size, extra_msa_chunk_size // 4
# )
# extra_msa_arg_tuples = [
# ("m", a),
# ("z", z),
# ("msa_mask", extra_msa_mask),
# ("pair_mask", pair_mask),
# ("chunk_size", torch.tensor(extra_msa_chunk_size)),
# ("use_lma", torch.tensor(model.globals.use_lma)),
# ("inplace_safe", torch.tensor(1)),
# ("_mask_trans", torch.tensor(model.config._mask_trans)),
# ("_attn_chunk_size", torch.tensor(extra_msa_attn_chunk_size)),
# ]
# verify_arg_order(
# model.extra_msa_stack.blocks[0].forward, extra_msa_arg_tuples
# )
# extra_msa_args = [arg for _, arg in extra_msa_arg_tuples]
# with torch.no_grad():
# traced_extra_msa_stack = []
# for b in model.extra_msa_stack.blocks:
# traced_block = trace_block(b, extra_msa_args)
# traced_extra_msa_stack.append(traced_block)
#
# del model.extra_msa_stack.blocks
# model.extra_msa_stack.blocks = traced_extra_msa_stack
# if(model.template_config.enabled):
# template_pair_stack_attn_chunk_size = max(
# model.globals.chunk_size, template_pair_stack_chunk_size // 4
# )
# template_pair_stack_arg_tuples = [
# ("z", t),
# ("mask", template_pair_mask),
# ("chunk_size", torch.tensor(template_pair_stack_chunk_size)),
# ("use_lma", torch.tensor(model.globals.use_lma)),
# ("inplace_safe", torch.tensor(1)),
# ("_mask_trans", torch.tensor(model.config._mask_trans)),
# ("_attn_chunk_size", torch.tensor(
# template_pair_stack_attn_chunk_size
# )),
# ]
# verify_arg_order(
# model.template_pair_stack.blocks[0].forward,
# template_pair_stack_arg_tuples
# )
# template_pair_stack_args = [
# arg for _, arg in template_pair_stack_arg_tuples
# ]
#
# with torch.no_grad():
# traced_template_pair_stack = []
# for b in model.template_pair_stack.blocks:
# traced_block = trace_block(b, template_pair_stack_args)
# traced_template_pair_stack.append(traced_block)
#
# del model.template_pair_stack.blocks
# model.template_pair_stack.blocks = traced_template_pair_stack
# We need to do another dry run after tracing to allow the model to reach
# top speeds. Why, I don't know.
two_recycling_iter_input
=
tensor_tree_map
(
lambda
t
:
t
[...,
:
2
],
sample_input
,
)
with
torch
.
no_grad
():
_
=
model
(
two_recycling_iter_input
)
run_pretrained_openfold.py
View file @
39a6d0e6
...
...
@@ -12,49 +12,158 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
from
datetime
import
date
import
logging
import
math
import
numpy
as
np
import
os
from
openfold.utils.script_utils
import
load_models_from_command_line
,
parse_fasta
,
run_model
,
prep_output
,
\
update_timings
,
relax_protein
logging
.
basicConfig
()
logger
=
logging
.
getLogger
(
__file__
)
logger
.
setLevel
(
level
=
logging
.
INFO
)
import
pickle
import
random
import
sys
import
time
import
torch
torch_versions
=
torch
.
__version__
.
split
(
"."
)
torch_major_version
=
int
(
torch_versions
[
0
])
torch_minor_version
=
int
(
torch_versions
[
1
])
if
(
torch_major_version
>
1
or
(
torch_major_version
==
1
and
torch_minor_version
>=
12
)
):
# Gives a large speedup on Ampere-class GPUs
torch
.
set_float32_matmul_precision
(
"high"
)
torch
.
set_grad_enabled
(
False
)
from
openfold.config
import
model_config
from
openfold.data
import
(
data_pipeline
,
feature_pipeline
,
templates
,
)
from
openfold.data.tools
import
hhsearch
,
hmmsearch
from
openfold.model.model
import
AlphaFold
from
openfold.model.torchscript
import
script_preset_
from
openfold.data
import
templates
,
feature_pipeline
,
data_pipeline
from
openfold.np
import
residue_constants
,
protein
import
openfold.np.relax.relax
as
relax
from
openfold.utils.import_weights
import
(
import_jax_weights_
,
)
from
openfold.utils.tensor_utils
import
(
tensor_tree_map
,
)
from
openfold.utils.trace_utils
import
(
pad_feature_dict_seq
,
trace_model_
,
)
from
scripts.utils
import
add_data_args
TRACING_INTERVAL
=
50
def precompute_alignments(tags, seqs, alignment_dir, args, is_multimer):
    """Ensure alignments exist for every (tag, seq) pair.

    For each sequence, a temporary one-record FASTA file is written and, when
    no precomputed alignments were supplied and the per-chain alignment
    directory does not yet exist, the alignment tool pipeline is run to
    generate them. Otherwise the existing alignments are reused.

    Args:
        tags: Sequence identifiers, parallel to ``seqs``.
        seqs: Amino-acid sequences, parallel to ``tags``.
        alignment_dir: Root directory holding (or to hold) alignments.
        args: Parsed command-line arguments (alignment tool binaries,
            database paths, ``cpus``, ``output_dir``, and
            ``use_precomputed_alignments``).
        is_multimer: If True, alignments live directly in ``alignment_dir``
            rather than in one subdirectory per chain tag.
    """
    for tag, seq in zip(tags, seqs):
        tmp_fasta_path = os.path.join(
            args.output_dir, f"tmp_{os.getpid()}.fasta"
        )
        with open(tmp_fasta_path, "w") as fp:
            fp.write(f">{tag}\n{seq}")

        if is_multimer:
            local_alignment_dir = alignment_dir
        else:
            # BUG FIX: the original computed
            # os.path.join(alignment_dir, os.path.join(alignment_dir, tag)),
            # which doubles the alignment directory component whenever
            # alignment_dir is a relative path.
            local_alignment_dir = os.path.join(alignment_dir, tag)

        if (args.use_precomputed_alignments is None and
                not os.path.isdir(local_alignment_dir)):
            logger.info(f"Generating alignments for {tag}...")

            os.makedirs(local_alignment_dir)

            alignment_runner = data_pipeline.AlignmentRunner(
                jackhmmer_binary_path=args.jackhmmer_binary_path,
                hhblits_binary_path=args.hhblits_binary_path,
                uniref90_database_path=args.uniref90_database_path,
                mgnify_database_path=args.mgnify_database_path,
                bfd_database_path=args.bfd_database_path,
                uniclust30_database_path=args.uniclust30_database_path,
                no_cpus=args.cpus,
            )
            alignment_runner.run(tmp_fasta_path, local_alignment_dir)
        else:
            logger.info(
                f"Using precomputed alignments for {tag} at {alignment_dir}..."
            )

        # Remove temporary FASTA file
        os.remove(tmp_fasta_path)
def round_up_seqlen(seqlen):
    """Round ``seqlen`` up to the next multiple of ``TRACING_INTERVAL``.

    Traced models are specialized to fixed input sizes; bucketing sequence
    lengths this way bounds how often the model must be re-traced.
    """
    # Integer ceiling division: exact for arbitrarily large ints, unlike the
    # original float-based int(math.ceil(seqlen / TRACING_INTERVAL)).
    return -(-seqlen // TRACING_INTERVAL) * TRACING_INTERVAL
def generate_feature_dict(
    tags,
    seqs,
    alignment_dir,
    data_processor,
    args,
):
    """Build the input feature dict for one prediction target.

    Writes the target's sequence(s) to a temporary FASTA file and dispatches
    to the appropriate data-processor entry point:

    * one sequence  -> ``process_fasta`` with a per-chain alignment dir;
    * multiple sequences, multimer preset -> ``process_fasta`` with the root
      alignment dir;
    * multiple sequences otherwise -> ``process_multiseq_fasta``.

    Args:
        tags: Sequence identifiers, parallel to ``seqs``.
        seqs: Amino-acid sequences, parallel to ``tags``.
        alignment_dir: Root directory containing precomputed alignments.
        data_processor: Object providing ``process_fasta`` /
            ``process_multiseq_fasta``.
        args: Parsed command-line arguments (``output_dir``,
            ``config_preset``).

    Returns:
        The feature dict produced by ``data_processor``.
    """
    tmp_fasta_path = os.path.join(
        args.output_dir, f"tmp_{os.getpid()}.fasta"
    )
    try:
        if len(seqs) == 1:
            tag = tags[0]
            seq = seqs[0]
            with open(tmp_fasta_path, "w") as fp:
                fp.write(f">{tag}\n{seq}")

            local_alignment_dir = os.path.join(alignment_dir, tag)
            feature_dict = data_processor.process_fasta(
                fasta_path=tmp_fasta_path, alignment_dir=local_alignment_dir
            )
        elif "multimer" in args.config_preset:
            with open(tmp_fasta_path, "w") as fp:
                fp.write(
                    '\n'.join([f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)])
                )
            feature_dict = data_processor.process_fasta(
                fasta_path=tmp_fasta_path,
                alignment_dir=alignment_dir,
            )
        else:
            with open(tmp_fasta_path, "w") as fp:
                fp.write(
                    '\n'.join([f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)])
                )
            feature_dict = data_processor.process_multiseq_fasta(
                fasta_path=tmp_fasta_path,
                super_alignment_dir=alignment_dir,
            )
    finally:
        # ROBUSTNESS FIX: remove the temporary FASTA file even when feature
        # processing raises (the original leaked it on error).
        if os.path.exists(tmp_fasta_path):
            os.remove(tmp_fasta_path)

    return feature_dict
def list_files_with_extensions(dir, extensions):
    """Return the entries of ``dir`` whose names end with ``extensions``.

    ``extensions`` may be a single suffix string or a tuple of suffixes, as
    accepted by ``str.endswith``.
    """
    matching = []
    for entry in os.listdir(dir):
        if entry.endswith(extensions):
            matching.append(entry)
    return matching
def
main
(
args
):
config
=
model_config
(
args
.
model_name
)
model
=
AlphaFold
(
config
)
model
=
model
.
eval
()
import_jax_weights_
(
model
,
args
.
param_path
,
version
=
args
.
model_name
)
#script_preset_(model)
model
=
model
.
to
(
args
.
model_device
)
# Create the output directory
os
.
makedirs
(
args
.
output_dir
,
exist_ok
=
True
)
config
=
model_config
(
args
.
config_preset
,
long_sequence_inference
=
args
.
long_sequence_inference
)
is_multimer
=
"multimer"
in
args
.
model_name
if
(
args
.
trace_model
):
if
(
not
config
.
data
.
predict
.
fixed_size
):
raise
ValueError
(
"Tracing requires that fixed_size mode be enabled in the config"
)
is_multimer
=
"multimer"
in
args
.
config_preset
if
(
is_multimer
):
if
(
not
args
.
use_precomputed_alignments
):
...
...
@@ -120,151 +229,150 @@ def main(args):
output_dir_base
=
args
.
output_dir
random_seed
=
args
.
data_random_seed
if
random_seed
is
None
:
random_seed
=
random
.
randrange
(
sys
.
maxsize
)
random_seed
=
random
.
randrange
(
2
**
32
)
feature_processor
=
feature_pipeline
.
FeaturePipeline
(
config
.
data
)
np
.
random
.
seed
(
random_seed
)
torch
.
manual_seed
(
random_seed
+
1
)
feature_processor
=
feature_pipeline
.
FeaturePipeline
(
config
.
data
)
if
not
os
.
path
.
exists
(
output_dir_base
):
os
.
makedirs
(
output_dir_base
)
if
(
not
args
.
use_precomputed_alignments
)
:
if
args
.
use_precomputed_alignments
is
None
:
alignment_dir
=
os
.
path
.
join
(
output_dir_base
,
"alignments"
)
else
:
alignment_dir
=
args
.
use_precomputed_alignments
for
fasta_path
in
os
.
listdir
(
args
.
fasta_dir
):
if
(
not
".fasta"
==
os
.
path
.
splitext
(
fasta_path
)[
-
1
]):
print
(
f
"Skipping
{
fasta_path
}
. Not a .fasta file..."
)
continue
fasta_path
=
os
.
path
.
join
(
args
.
fasta_dir
,
fasta_path
)
tag_list
=
[]
seq_list
=
[]
for
fasta_file
in
list_files_with_extensions
(
args
.
fasta_dir
,
(
".fasta"
,
".fa"
)):
# Gather input sequences
fasta_path
=
os
.
path
.
join
(
args
.
fasta_dir
,
fasta_file
)
with
open
(
fasta_path
,
"r"
)
as
fp
:
data
=
fp
.
read
()
tags
,
seqs
=
parse_fasta
(
data
)
lines
=
[
l
.
replace
(
'
\n
'
,
''
)
for
prot
in
data
.
split
(
'>'
)
for
l
in
prot
.
strip
().
split
(
'
\n
'
,
1
)
][
1
:]
tags
,
seqs
=
lines
[::
2
],
lines
[
1
::
2
]
if
((
not
is_multimer
)
and
len
(
tags
)
!=
1
):
if
((
not
is_multimer
)
and
len
(
tags
)
!=
1
):
print
(
f
"
{
fasta_path
}
contains more than one sequence but "
f
"multimer mode is not enabled. Skipping..."
)
continue
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tag
=
'-'
.
join
(
tags
)
tag_list
.
append
((
tag
,
tags
))
seq_list
.
append
(
seqs
)
seq_sort_fn
=
lambda
target
:
sum
([
len
(
s
)
for
s
in
target
[
1
]])
sorted_targets
=
sorted
(
zip
(
tag_list
,
seq_list
),
key
=
seq_sort_fn
)
feature_dicts
=
{}
model_generator
=
load_models_from_command_line
(
config
,
args
.
model_device
,
args
.
openfold_checkpoint_path
,
args
.
jax_param_path
,
args
.
output_dir
)
for
model
,
output_directory
in
model_generator
:
cur_tracing_interval
=
0
for
(
tag
,
tags
),
seqs
in
sorted_targets
:
output_name
=
f
'
{
tag
}
_
{
args
.
config_preset
}
'
if
args
.
output_postfix
is
not
None
:
output_name
=
f
'
{
output_name
}
_
{
args
.
output_postfix
}
'
# Does nothing if the alignments have already been computed
precompute_alignments
(
tags
,
seqs
,
alignment_dir
,
args
,
is_multimer
)
for
tag
,
seq
in
zip
(
tags
,
seqs
):
tag
,
seq
=
tags
[
0
],
seqs
[
0
]
local_alignment_dir
=
os
.
path
.
join
(
alignment_dir
,
tag
)
if
(
args
.
use_precomputed_alignments
is
None
):
if
not
os
.
path
.
exists
(
local_alignment_dir
):
os
.
makedirs
(
local_alignment_dir
)
alignment_runner
.
run
(
fasta_path
,
local_alignment_dir
feature_dict
=
feature_dicts
.
get
(
tag
,
None
)
if
(
feature_dict
is
None
):
feature_dict
=
generate_feature_dict
(
tags
,
seqs
,
alignment_dir
,
data_processor
,
args
,
)
if
(
is_multimer
):
local_alignment_dir
=
alignment_dir
else
:
local_alignment_dir
=
os
.
path
.
join
(
alignment_dir
,
tags
[
0
],
)
feature_dict
=
data_processor
.
process_fasta
(
fasta_path
=
fasta_path
,
alignment_dir
=
local_alignment_dir
)
if
(
args
.
trace_model
):
n
=
feature_dict
[
"aatype"
].
shape
[
-
2
]
rounded_seqlen
=
round_up_seqlen
(
n
)
feature_dict
=
pad_feature_dict_seq
(
feature_dict
,
rounded_seqlen
,
)
processed_feature_dict
=
feature_processor
.
process_features
(
feature_dict
,
mode
=
'predict'
,
is_multimer
=
is_multimer
,
)
logging
.
info
(
"Executing model..."
)
batch
=
processed_feature_dict
with
torch
.
no_grad
():
batch
=
{
feature_dicts
[
tag
]
=
feature_dict
processed_feature_dict
=
feature_processor
.
process_features
(
feature_dict
,
mode
=
'predict'
,
is_multimer
=
is_multimer
)
processed_feature_dict
=
{
k
:
torch
.
as_tensor
(
v
,
device
=
args
.
model_device
)
for
k
,
v
in
batch
.
items
()
for
k
,
v
in
processed_feature_dict
.
items
()
}
t
=
time
.
perf_counter
()
chunk_size
=
model
.
globals
.
chunk_size
try
:
model
.
globals
.
chunk_size
=
None
out
=
model
(
batch
)
except
RuntimeError
as
e
:
model
.
globals
.
chunk_size
=
chunk_size
out
=
model
(
batch
)
logging
.
info
(
f
"Inference time:
{
time
.
perf_counter
()
-
t
}
"
)
if
(
args
.
trace_model
):
if
(
rounded_seqlen
>
cur_tracing_interval
):
logger
.
info
(
f
"Tracing model at
{
rounded_seqlen
}
residues..."
)
t
=
time
.
perf_counter
()
trace_model_
(
model
,
processed_feature_dict
)
tracing_time
=
time
.
perf_counter
()
-
t
logger
.
info
(
f
"Tracing time:
{
tracing_time
}
"
)
cur_tracing_interval
=
rounded_seqlen
out
=
run_model
(
model
,
processed_feature_dict
,
tag
,
args
.
output_dir
)
# Toss out the recycling dimensions --- we don't need them anymore
batch
=
tensor_tree_map
(
lambda
x
:
np
.
array
(
x
[...,
-
1
].
cpu
()),
batch
)
out
=
tensor_tree_map
(
lambda
x
:
np
.
array
(
x
.
cpu
()),
out
)
plddt
=
out
[
"plddt"
]
mean_plddt
=
np
.
mean
(
plddt
)
plddt_b_factors
=
np
.
repeat
(
plddt
[...,
None
],
residue_constants
.
atom_type_num
,
axis
=-
1
processed_feature_dict
=
tensor_tree_map
(
lambda
x
:
np
.
array
(
x
[...,
-
1
].
cpu
()),
processed_feature_dict
)
unrelaxed_protein
=
protein
.
from_prediction
(
features
=
batch
,
result
=
out
,
b_factors
=
plddt_b_factors
,
remove_leading_feature_dimension
=
not
is_multimer
,
out
=
tensor_tree_map
(
lambda
x
:
np
.
array
(
x
.
cpu
()),
out
)
unrelaxed_protein
=
prep_output
(
out
,
processed_feature_dict
,
feature_dict
,
feature_processor
,
args
.
config_preset
,
args
.
multimer_ri_gap
,
args
.
subtract_plddt
)
# Save the unrelaxed PDB.
unrelaxed_output_path
=
os
.
path
.
join
(
args
.
output_dir
,
f
'
{
tag
}
_
{
args
.
model
_name
}
_unrelaxed.pdb'
output_dir
ectory
,
f
'
{
output
_name
}
_unrelaxed.pdb'
)
with
open
(
unrelaxed_output_path
,
'w'
)
as
f
:
f
.
write
(
protein
.
to_pdb
(
unrelaxed_protein
))
print
(
unrelaxed_output_path
)
print
(
"asdjfh klasjdhf lkasjdhf lkjasdhflkjasdh fkl jasdhfklj hasdkljf hasldkjfh lkasjdfh lkajsdhflk asd"
)
with
open
(
unrelaxed_output_path
,
'w'
)
as
fp
:
fp
.
write
(
protein
.
to_pdb
(
unrelaxed_protein
)
)
amber_relaxer
=
relax
.
AmberRelaxation
(
use_gpu
=
(
args
.
model_device
!=
"cpu"
),
**
config
.
relax
,
)
# Relax the prediction.
t
=
time
.
perf_counter
()
visible_devices
=
os
.
getenv
(
"CUDA_VISIBLE_DEVICES"
,
default
=
""
)
if
(
"cuda"
in
args
.
model_device
):
device_no
=
args
.
model_device
.
split
(
":"
)[
-
1
]
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
device_no
relaxed_pdb_str
,
_
,
_
=
amber_relaxer
.
process
(
prot
=
unrelaxed_protein
)
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
visible_devices
logging
.
info
(
f
"Relaxation time:
{
time
.
perf_counter
()
-
t
}
"
)
# Save the relaxed PDB.
relaxed_output_path
=
os
.
path
.
join
(
args
.
output_dir
,
f
'
{
tag
}
_
{
args
.
model_name
}
_relaxed.pdb'
)
with
open
(
relaxed_output_path
,
'w'
)
as
f
:
f
.
write
(
relaxed_pdb_str
)
logger
.
info
(
f
"Output written to
{
unrelaxed_output_path
}
..."
)
if
not
args
.
skip_relaxation
:
# Relax the prediction.
logger
.
info
(
f
"Running relaxation on
{
unrelaxed_output_path
}
..."
)
relax_protein
(
config
,
args
.
model_device
,
unrelaxed_protein
,
output_directory
,
output_name
)
if
(
args
.
save_outputs
)
:
if
args
.
save_outputs
:
output_dict_path
=
os
.
path
.
join
(
args
.
output_dir
,
f
'
{
tag
}
_
{
args
.
model
_name
}
_output_dict.pkl'
output_dir
ectory
,
f
'
{
output
_name
}
_output_dict.pkl'
)
with
open
(
output_dict_path
,
"wb"
)
as
fp
:
pickle
.
dump
(
out
,
fp
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
logger
.
info
(
f
"Model output written to
{
output_dict_path
}
..."
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"fasta_dir"
,
type
=
str
,
help
=
"Path to directory containing FASTA files, one sequence per file"
)
parser
.
add_argument
(
"template_mmcif_dir"
,
type
=
str
,
...
...
@@ -284,18 +392,22 @@ if __name__ == "__main__":
device name is accepted (e.g. "cpu", "cuda:0")"""
)
parser
.
add_argument
(
"--model_name"
,
type
=
str
,
default
=
"model_1"
,
help
=
"""Name of a model config. Choose one of model_{1-5} or
model_{1-5}_ptm, as defined on the AlphaFold GitHub."""
"--config_preset"
,
type
=
str
,
default
=
"model_1"
,
help
=
"""Name of a model config preset defined in openfold/config.py"""
)
parser
.
add_argument
(
"--jax_param_path"
,
type
=
str
,
default
=
None
,
help
=
"""Path to JAX model parameters. If None, and openfold_checkpoint_path
is also None, parameters are selected automatically according to
the model name from openfold/resources/params"""
)
parser
.
add_argument
(
"--param_path"
,
type
=
str
,
default
=
None
,
help
=
"""Path to model parameters. If None, parameters are selected
automatically according to the model name from
openfold/resources/params"""
"--openfold_checkpoint_path"
,
type
=
str
,
default
=
None
,
help
=
"""Path to OpenFold checkpoint. Can be either a DeepSpeed
checkpoint directory or a .pt file"""
)
parser
.
add_argument
(
"--save_outputs"
,
type
=
bool
,
default
=
False
,
"--save_outputs"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to save all model outputs, including embeddings, etc."
)
parser
.
add_argument
(
...
...
@@ -303,19 +415,45 @@ if __name__ == "__main__":
help
=
"""Number of CPUs with which to run alignment tools"""
)
parser
.
add_argument
(
'
--preset
'
,
type
=
str
,
default
=
'full_dbs'
,
"
--preset
"
,
type
=
str
,
default
=
'full_dbs'
,
choices
=
(
'reduced_dbs'
,
'full_dbs'
)
)
parser
.
add_argument
(
'--data_random_seed'
,
type
=
str
,
default
=
None
"--output_postfix"
,
type
=
str
,
default
=
None
,
help
=
"""Postfix for output prediction filenames"""
)
parser
.
add_argument
(
"--data_random_seed"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--skip_relaxation"
,
action
=
"store_true"
,
default
=
False
,
)
parser
.
add_argument
(
"--multimer_ri_gap"
,
type
=
int
,
default
=
200
,
help
=
"""Residue index offset between multiple sequences, if provided"""
)
parser
.
add_argument
(
"--trace_model"
,
action
=
"store_true"
,
default
=
False
,
help
=
"""Whether to convert parts of each model to TorchScript.
Significantly improves runtime at the cost of lengthy
'compilation.' Useful for large batch jobs."""
)
parser
.
add_argument
(
"--subtract_plddt"
,
action
=
"store_true"
,
default
=
False
,
help
=
""""Whether to output (100 - pLDDT) in the B-factor column instead
of the pLDDT itself"""
)
parser
.
add_argument
(
"--long_sequence_inference"
,
action
=
"store_true"
,
default
=
False
,
help
=
"""enable options to reduce memory usage at the cost of speed, helps longer sequences fit into GPU memory, see the README for details"""
)
add_data_args
(
parser
)
args
=
parser
.
parse_args
()
if
(
args
.
param_path
is
None
):
args
.
param_path
=
os
.
path
.
join
(
if
(
args
.
jax_
param_path
is
None
and
args
.
openfold_checkpoint_path
is
None
):
args
.
jax_
param_path
=
os
.
path
.
join
(
"openfold"
,
"resources"
,
"params"
,
"params_"
+
args
.
model_name
+
".npz"
"params_"
+
args
.
config_preset
+
".npz"
)
if
(
args
.
model_device
==
"cpu"
and
torch
.
cuda
.
is_available
()):
...
...
scripts/alignment_db_scripts/create_alignment_db.py
0 → 100644
View file @
39a6d0e6
import
argparse
import
json
import
os
def main(args):
    """Concatenate every alignment file under ``args.alignment_dir`` into a
    single database file and write a JSON index describing it.

    The index maps each chain directory name to a list of
    ``(file_name, byte_offset, byte_length)`` entries locating that file's
    contents within the database.

    Args:
        args: Parsed command-line arguments with ``alignment_dir``,
            ``output_db_path`` and ``output_db_name``.
    """
    db_path = os.path.join(args.output_db_path, f"{args.output_db_name}.db")
    index_path = os.path.join(
        args.output_db_path, f"{args.output_db_name}.index"
    )

    index = {}
    db_offset = 0
    # ROBUSTNESS FIX: open the database with a context manager so the handle
    # is closed even if reading one of the alignment files fails (the
    # original used a bare open()/close() pair).
    with open(db_path, "wb") as db_fp:
        for chain_alignment_dir in os.listdir(args.alignment_dir):
            cad_path = os.path.join(args.alignment_dir, chain_alignment_dir)
            for f in os.listdir(cad_path):
                f_path = os.path.join(cad_path, f)
                with open(f_path, "rb") as fp:
                    file_bytes = fp.read()

                length = len(file_bytes)
                file_list = index.setdefault(chain_alignment_dir, [])
                file_list.append((f, db_offset, length))

                db_fp.write(file_bytes)
                db_offset += length

    with open(index_path, "w") as fp:
        json.dump(index, fp)
# Command-line entry point: build a single-file alignment database (plus its
# JSON index) from a directory of per-chain alignment subdirectories.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "alignment_dir",
        type=str,
        help="""Path to precomputed alignment directory, with one subdirectory
per chain."""
    )
    # Directory in which to write the .db and .index outputs.
    parser.add_argument("output_db_path", type=str)
    # Basename (no extension) shared by the .db and .index outputs.
    parser.add_argument("output_db_name", type=str)
    args = parser.parse_args()

    main(args)
scripts/alignment_db_scripts/unify_alignment_db_indices.py
0 → 100644
View file @
39a6d0e6
import
argparse
import
json
import
os
""" Unifies databases created with create_alignment_db.py """
def main(args):
    """Merge the per-database ``.index`` files in ``args.alignment_db_dir``
    into one ``super.index`` written to ``args.output_dir``.

    Each entry of the super index maps a chain name to the database file that
    holds its alignments and the file list copied from that database's index:
    ``{"db": "<name>.db", "files": [...]}``.

    Args:
        args: Parsed command-line arguments with ``alignment_db_dir`` and
            ``output_dir``.
    """
    super_index = {}
    for f in os.listdir(args.alignment_db_dir):
        # Only .index files describe a database; skip everything else,
        # including the .db files themselves.
        # IDIOM FIX: was the awkward `not os.path.splitext(f)[-1] == ".index"`.
        if os.path.splitext(f)[-1] != ".index":
            continue

        with open(os.path.join(args.alignment_db_dir, f), "r") as fp:
            index = json.load(fp)

        # The database file is assumed to sit next to its index with the
        # same basename (create_alignment_db.py writes them that way).
        db_name = f"{os.path.splitext(f)[0]}.db"
        for k in index:
            super_index[k] = {
                "db": db_name,
                "files": index[k],
            }

    with open(os.path.join(args.output_dir, "super.index"), "w") as fp:
        json.dump(super_index, fp)
# Command-line entry point: unify the .index files of several alignment
# databases (produced by create_alignment_db.py) into a single super index.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "alignment_db_dir",
        type=str,
        help="Path to directory containing alignment_dbs"
    )
    parser.add_argument(
        "output_dir",
        type=str,
        help="Path in which to output super index"
    )
    args = parser.parse_args()

    main(args)
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment