Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
07e64267
Commit
07e64267
authored
Oct 16, 2021
by
Gustaf Ahdritz
Browse files
Standardize code style
parent
de07730f
Changes
60
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2697 additions
and
2159 deletions
+2697
-2159
openfold/model/outer_product_mean.py
openfold/model/outer_product_mean.py
+22
-21
openfold/model/pair_transition.py
openfold/model/pair_transition.py
+17
-16
openfold/model/primitives.py
openfold/model/primitives.py
+117
-103
openfold/model/structure_module.py
openfold/model/structure_module.py
+172
-174
openfold/model/template.py
openfold/model/template.py
+90
-89
openfold/model/triangular_attention.py
openfold/model/triangular_attention.py
+32
-37
openfold/model/triangular_multiplicative_update.py
openfold/model/triangular_multiplicative_update.py
+37
-30
openfold/np/__init__.py
openfold/np/__init__.py
+6
-4
openfold/np/protein.py
openfold/np/protein.py
+203
-187
openfold/np/relax/__init__.py
openfold/np/relax/__init__.py
+6
-5
openfold/np/relax/amber_minimize.py
openfold/np/relax/amber_minimize.py
+504
-443
openfold/np/relax/cleanup.py
openfold/np/relax/cleanup.py
+91
-87
openfold/np/relax/relax.py
openfold/np/relax/relax.py
+60
-53
openfold/np/relax/utils.py
openfold/np/relax/utils.py
+50
-45
openfold/np/residue_constants.py
openfold/np/residue_constants.py
+1035
-652
openfold/utils/__init__.py
openfold/utils/__init__.py
+6
-5
openfold/utils/affine_utils.py
openfold/utils/affine_utils.py
+112
-86
openfold/utils/deepspeed.py
openfold/utils/deepspeed.py
+23
-21
openfold/utils/exponential_moving_average.py
openfold/utils/exponential_moving_average.py
+28
-24
openfold/utils/feats.py
openfold/utils/feats.py
+86
-77
No files found.
openfold/model/outer_product_mean.py
View file @
07e64267
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
# Copyright 2021 DeepMind Technologies Limited
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
...
@@ -23,17 +23,18 @@ from openfold.utils.tensor_utils import chunk_layer
...
@@ -23,17 +23,18 @@ from openfold.utils.tensor_utils import chunk_layer
class
OuterProductMean
(
nn
.
Module
):
class
OuterProductMean
(
nn
.
Module
):
"""
"""
Implements Algorithm 10.
Implements Algorithm 10.
"""
"""
def
__init__
(
self
,
c_m
,
c_z
,
c_hidden
,
chunk_size
=
4
,
eps
=
1e-3
):
def
__init__
(
self
,
c_m
,
c_z
,
c_hidden
,
chunk_size
=
4
,
eps
=
1e-3
):
"""
"""
Args:
Args:
c_m:
c_m:
MSA embedding channel dimension
MSA embedding channel dimension
c_z:
c_z:
Pair embedding channel dimension
Pair embedding channel dimension
c_hidden:
c_hidden:
Hidden channel dimension
Hidden channel dimension
"""
"""
super
(
OuterProductMean
,
self
).
__init__
()
super
(
OuterProductMean
,
self
).
__init__
()
...
@@ -45,12 +46,12 @@ class OuterProductMean(nn.Module):
...
@@ -45,12 +46,12 @@ class OuterProductMean(nn.Module):
self
.
layer_norm
=
nn
.
LayerNorm
(
c_m
)
self
.
layer_norm
=
nn
.
LayerNorm
(
c_m
)
self
.
linear_1
=
Linear
(
c_m
,
c_hidden
)
self
.
linear_1
=
Linear
(
c_m
,
c_hidden
)
self
.
linear_2
=
Linear
(
c_m
,
c_hidden
)
self
.
linear_2
=
Linear
(
c_m
,
c_hidden
)
self
.
linear_out
=
Linear
(
c_hidden
**
2
,
c_z
,
init
=
"final"
)
self
.
linear_out
=
Linear
(
c_hidden
**
2
,
c_z
,
init
=
"final"
)
def
_opm
(
self
,
a
,
b
):
def
_opm
(
self
,
a
,
b
):
# [*, N_res, N_res, C, C]
# [*, N_res, N_res, C, C]
outer
=
torch
.
einsum
(
"...bac,...dae->...bdce"
,
a
,
b
)
outer
=
torch
.
einsum
(
"...bac,...dae->...bdce"
,
a
,
b
)
# [*, N_res, N_res, C * C]
# [*, N_res, N_res, C * C]
outer
=
outer
.
reshape
(
*
outer
.
shape
[:
-
2
],
-
1
)
outer
=
outer
.
reshape
(
*
outer
.
shape
[:
-
2
],
-
1
)
...
@@ -61,20 +62,20 @@ class OuterProductMean(nn.Module):
...
@@ -61,20 +62,20 @@ class OuterProductMean(nn.Module):
def
forward
(
self
,
m
,
mask
=
None
):
def
forward
(
self
,
m
,
mask
=
None
):
"""
"""
Args:
Args:
m:
m:
[*, N_seq, N_res, C_m] MSA embedding
[*, N_seq, N_res, C_m] MSA embedding
mask:
mask:
[*, N_seq, N_res] MSA mask
[*, N_seq, N_res] MSA mask
Returns:
Returns:
[*, N_res, N_res, C_z] pair embedding update
[*, N_res, N_res, C_z] pair embedding update
"""
"""
if
(
mask
is
None
)
:
if
mask
is
None
:
mask
=
m
.
new_ones
(
m
.
shape
[:
-
1
])
mask
=
m
.
new_ones
(
m
.
shape
[:
-
1
])
# [*, N_seq, N_res, C_m]
# [*, N_seq, N_res, C_m]
m
=
self
.
layer_norm
(
m
)
m
=
self
.
layer_norm
(
m
)
# [*, N_seq, N_res, C]
# [*, N_seq, N_res, C]
mask
=
mask
.
unsqueeze
(
-
1
)
mask
=
mask
.
unsqueeze
(
-
1
)
a
=
self
.
linear_1
(
m
)
*
mask
a
=
self
.
linear_1
(
m
)
*
mask
...
@@ -83,7 +84,7 @@ class OuterProductMean(nn.Module):
...
@@ -83,7 +84,7 @@ class OuterProductMean(nn.Module):
a
=
a
.
transpose
(
-
2
,
-
3
)
a
=
a
.
transpose
(
-
2
,
-
3
)
b
=
b
.
transpose
(
-
2
,
-
3
)
b
=
b
.
transpose
(
-
2
,
-
3
)
if
(
self
.
chunk_size
is
not
None
)
:
if
self
.
chunk_size
is
not
None
:
# Since the "batch dim" in this case is not a true batch dimension
# Since the "batch dim" in this case is not a true batch dimension
# (in that the shape of the output depends on it), we need to
# (in that the shape of the output depends on it), we need to
# iterate over it ourselves
# iterate over it ourselves
...
...
openfold/model/pair_transition.py
View file @
07e64267
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
# Copyright 2021 DeepMind Technologies Limited
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
...
@@ -22,16 +22,17 @@ from openfold.utils.tensor_utils import chunk_layer
...
@@ -22,16 +22,17 @@ from openfold.utils.tensor_utils import chunk_layer
class
PairTransition
(
nn
.
Module
):
class
PairTransition
(
nn
.
Module
):
"""
"""
Implements Algorithm 15.
Implements Algorithm 15.
"""
"""
def
__init__
(
self
,
c_z
,
n
,
chunk_size
=
4
):
def
__init__
(
self
,
c_z
,
n
,
chunk_size
=
4
):
"""
"""
Args:
Args:
c_z:
c_z:
Pair transition channel dimension
Pair transition channel dimension
n:
n:
Factor by which c_z is multiplied to obtain hidden channel
Factor by which c_z is multiplied to obtain hidden channel
dimension
dimension
"""
"""
super
(
PairTransition
,
self
).
__init__
()
super
(
PairTransition
,
self
).
__init__
()
...
@@ -56,14 +57,14 @@ class PairTransition(nn.Module):
...
@@ -56,14 +57,14 @@ class PairTransition(nn.Module):
def
forward
(
self
,
z
,
mask
=
None
):
def
forward
(
self
,
z
,
mask
=
None
):
"""
"""
Args:
Args:
z:
z:
[*, N_res, N_res, C_z] pair embedding
[*, N_res, N_res, C_z] pair embedding
Returns:
Returns:
[*, N_res, N_res, C_z] pair embedding update
[*, N_res, N_res, C_z] pair embedding update
"""
"""
# DISCREPANCY: DeepMind forgets to apply the mask in this module.
# DISCREPANCY: DeepMind forgets to apply the mask in this module.
if
(
mask
is
None
)
:
if
mask
is
None
:
mask
=
z
.
new_ones
(
z
.
shape
[:
-
1
])
mask
=
z
.
new_ones
(
z
.
shape
[:
-
1
])
# [*, N_res, N_res, 1]
# [*, N_res, N_res, 1]
...
@@ -73,12 +74,12 @@ class PairTransition(nn.Module):
...
@@ -73,12 +74,12 @@ class PairTransition(nn.Module):
z
=
self
.
layer_norm
(
z
)
z
=
self
.
layer_norm
(
z
)
inp
=
{
"z"
:
z
,
"mask"
:
mask
}
inp
=
{
"z"
:
z
,
"mask"
:
mask
}
if
(
self
.
chunk_size
is
not
None
)
:
if
self
.
chunk_size
is
not
None
:
z
=
chunk_layer
(
z
=
chunk_layer
(
self
.
_transition
,
self
.
_transition
,
inp
,
inp
,
chunk_size
=
self
.
chunk_size
,
chunk_size
=
self
.
chunk_size
,
no_batch_dims
=
len
(
z
.
shape
[:
-
2
]),
no_batch_dims
=
len
(
z
.
shape
[:
-
2
]),
)
)
else
:
else
:
z
=
self
.
_transition
(
**
inp
)
z
=
self
.
_transition
(
**
inp
)
...
...
openfold/model/primitives.py
View file @
07e64267
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
# Copyright 2021 DeepMind Technologies Limited
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
# limitations under the License.
# limitations under the License.
import
math
import
math
from
typing
import
Optional
,
Callable
,
List
from
typing
import
Optional
,
Callable
,
List
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -22,7 +22,7 @@ import torch.nn as nn
...
@@ -22,7 +22,7 @@ import torch.nn as nn
from
scipy.stats
import
truncnorm
from
scipy.stats
import
truncnorm
from
openfold.utils.tensor_utils
import
(
from
openfold.utils.tensor_utils
import
(
permute_final_dims
,
permute_final_dims
,
flatten_final_dims
,
flatten_final_dims
,
)
)
...
@@ -33,6 +33,7 @@ def _prod(nums):
...
@@ -33,6 +33,7 @@ def _prod(nums):
out
=
out
*
n
out
=
out
*
n
return
out
return
out
def
_calculate_fan
(
shape
,
fan
=
"fan_in"
):
def
_calculate_fan
(
shape
,
fan
=
"fan_in"
):
i
=
shape
[
0
]
i
=
shape
[
0
]
o
=
shape
[
1
]
o
=
shape
[
1
]
...
@@ -40,20 +41,20 @@ def _calculate_fan(shape, fan="fan_in"):
...
@@ -40,20 +41,20 @@ def _calculate_fan(shape, fan="fan_in"):
fan_in
=
prod
*
i
fan_in
=
prod
*
i
fan_out
=
prod
*
o
fan_out
=
prod
*
o
if
(
fan
==
"fan_in"
)
:
if
fan
==
"fan_in"
:
f
=
fan_in
f
=
fan_in
elif
(
fan
==
"fan_out"
)
:
elif
fan
==
"fan_out"
:
f
=
fan_out
f
=
fan_out
elif
(
fan
==
"fan_avg"
)
:
elif
fan
==
"fan_avg"
:
f
=
(
fan_in
+
fan_out
)
/
2
f
=
(
fan_in
+
fan_out
)
/
2
else
:
else
:
raise
ValueError
(
"Invalid fan option"
)
raise
ValueError
(
"Invalid fan option"
)
return
f
return
f
def
trunc_normal_init_
(
weights
,
scale
=
1.0
,
fan
=
"fan_in"
):
def
trunc_normal_init_
(
weights
,
scale
=
1.0
,
fan
=
"fan_in"
):
shape
=
weights
.
shape
shape
=
weights
.
shape
f
=
_calculate_fan
(
shape
,
fan
)
f
=
_calculate_fan
(
shape
,
fan
)
scale
=
scale
/
max
(
1
,
f
)
scale
=
scale
/
max
(
1
,
f
)
a
=
-
2
a
=
-
2
...
@@ -80,17 +81,17 @@ def glorot_uniform_init_(weights):
...
@@ -80,17 +81,17 @@ def glorot_uniform_init_(weights):
def
final_init_
(
weights
):
def
final_init_
(
weights
):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
weights
.
fill_
(
0.
)
weights
.
fill_
(
0.
0
)
def
gating_init_
(
weights
):
def
gating_init_
(
weights
):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
weights
.
fill_
(
0.
)
weights
.
fill_
(
0.
0
)
def
normal_init_
(
weights
):
def
normal_init_
(
weights
):
torch
.
nn
.
init
.
kaiming_normal_
(
weights
,
nonlinearity
=
"linear"
)
torch
.
nn
.
init
.
kaiming_normal_
(
weights
,
nonlinearity
=
"linear"
)
def
ipa_point_weights_init_
(
weights
):
def
ipa_point_weights_init_
(
weights
):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
...
@@ -100,98 +101,101 @@ def ipa_point_weights_init_(weights):
...
@@ -100,98 +101,101 @@ def ipa_point_weights_init_(weights):
class
Linear
(
nn
.
Linear
):
class
Linear
(
nn
.
Linear
):
"""
"""
A Linear layer with built-in nonstandard initializations. Called just
A Linear layer with built-in nonstandard initializations. Called just
like torch.nn.Linear.
like torch.nn.Linear.
Implements the initializers in 1.11.4, plus some additional ones found
Implements the initializers in 1.11.4, plus some additional ones found
in the code.
in the code.
"""
"""
def
__init__
(
self
,
def
__init__
(
in_dim
:
int
,
self
,
out_dim
:
int
,
in_dim
:
int
,
bias
:
bool
=
True
,
out_dim
:
int
,
init
:
str
=
"default"
,
bias
:
bool
=
True
,
init
:
str
=
"default"
,
init_fn
:
Optional
[
Callable
[[
torch
.
Tensor
,
torch
.
Tensor
],
None
]]
=
None
,
init_fn
:
Optional
[
Callable
[[
torch
.
Tensor
,
torch
.
Tensor
],
None
]]
=
None
,
):
):
"""
"""
Args:
Args:
in_dim:
in_dim:
The final dimension of inputs to the layer
The final dimension of inputs to the layer
out_dim:
out_dim:
The final dimension of layer outputs
The final dimension of layer outputs
bias:
bias:
Whether to learn an additive bias. True by default
Whether to learn an additive bias. True by default
init:
init:
The initializer to use. Choose from:
The initializer to use. Choose from:
"default": LeCun fan-in truncated normal initialization
"default": LeCun fan-in truncated normal initialization
"relu": He initialization w/ truncated normal distribution
"relu": He initialization w/ truncated normal distribution
"glorot": Fan-average Glorot uniform initialization
"glorot": Fan-average Glorot uniform initialization
"gating": Weights=0, Bias=1
"gating": Weights=0, Bias=1
"normal": Normal initialization with std=1/sqrt(fan_in)
"normal": Normal initialization with std=1/sqrt(fan_in)
"final": Weights=0, Bias=0
"final": Weights=0, Bias=0
Overridden by init_fn if the latter is not None.
Overridden by init_fn if the latter is not None.
init_fn:
init_fn:
A custom initializer taking weight and bias as inputs.
A custom initializer taking weight and bias as inputs.
Overrides init if not None.
Overrides init if not None.
"""
"""
super
(
Linear
,
self
).
__init__
(
in_dim
,
out_dim
,
bias
=
bias
)
super
(
Linear
,
self
).
__init__
(
in_dim
,
out_dim
,
bias
=
bias
)
if
(
bias
)
:
if
bias
:
with
torch
.
no_grad
():
with
torch
.
no_grad
():
self
.
bias
.
fill_
(
0
)
self
.
bias
.
fill_
(
0
)
if
(
init_fn
is
not
None
)
:
if
init_fn
is
not
None
:
init_fn
(
self
.
weight
,
self
.
bias
)
init_fn
(
self
.
weight
,
self
.
bias
)
else
:
else
:
if
(
init
==
"default"
)
:
if
init
==
"default"
:
lecun_normal_init_
(
self
.
weight
)
lecun_normal_init_
(
self
.
weight
)
elif
(
init
==
"relu"
)
:
elif
init
==
"relu"
:
he_normal_init_
(
self
.
weight
)
he_normal_init_
(
self
.
weight
)
elif
(
init
==
"glorot"
)
:
elif
init
==
"glorot"
:
glorot_uniform_init_
(
self
.
weight
)
glorot_uniform_init_
(
self
.
weight
)
elif
(
init
==
"gating"
)
:
elif
init
==
"gating"
:
gating_init_
(
self
.
weight
)
gating_init_
(
self
.
weight
)
if
(
bias
)
:
if
bias
:
with
torch
.
no_grad
():
with
torch
.
no_grad
():
self
.
bias
.
fill_
(
1.
)
self
.
bias
.
fill_
(
1.
0
)
elif
(
init
==
"normal"
)
:
elif
init
==
"normal"
:
normal_init_
(
self
.
weight
)
normal_init_
(
self
.
weight
)
elif
(
init
==
"final"
)
:
elif
init
==
"final"
:
final_init_
(
self
.
weight
)
final_init_
(
self
.
weight
)
else
:
else
:
raise
ValueError
(
"Invalid init string."
)
raise
ValueError
(
"Invalid init string."
)
class
Attention
(
nn
.
Module
):
class
Attention
(
nn
.
Module
):
"""
Standard multi-head attention using AlphaFold's default layer
initialization.
"""
"""
def
__init__
(
self
,
Standard multi-head attention using AlphaFold's default layer
c_q
:
int
,
initialization.
c_k
:
int
,
"""
c_v
:
int
,
c_hidden
:
int
,
def
__init__
(
no_heads
:
int
,
self
,
c_q
:
int
,
c_k
:
int
,
c_v
:
int
,
c_hidden
:
int
,
no_heads
:
int
,
gating
:
bool
=
True
,
gating
:
bool
=
True
,
):
):
"""
"""
Args:
Args:
c_q:
c_q:
Input dimension of query data
Input dimension of query data
c_k:
c_k:
Input dimension of key data
Input dimension of key data
c_v:
c_v:
Input dimension of value data
Input dimension of value data
c_hidden:
c_hidden:
Per-head hidden dimension
Per-head hidden dimension
no_heads:
no_heads:
Number of attention heads
Number of attention heads
gating:
gating:
Whether the output should be gated using query data
Whether the output should be gated using query data
"""
"""
super
(
Attention
,
self
).
__init__
()
super
(
Attention
,
self
).
__init__
()
...
@@ -202,7 +206,7 @@ class Attention(nn.Module):
...
@@ -202,7 +206,7 @@ class Attention(nn.Module):
self
.
no_heads
=
no_heads
self
.
no_heads
=
no_heads
self
.
gating
=
gating
self
.
gating
=
gating
# DISCREPANCY: c_hidden is not the per-head channel dimension, as
# DISCREPANCY: c_hidden is not the per-head channel dimension, as
# stated in the supplement, but the overall channel dimension
# stated in the supplement, but the overall channel dimension
self
.
linear_q
=
Linear
(
self
.
linear_q
=
Linear
(
...
@@ -218,28 +222,31 @@ class Attention(nn.Module):
...
@@ -218,28 +222,31 @@ class Attention(nn.Module):
self
.
c_hidden
*
self
.
no_heads
,
self
.
c_q
,
init
=
"final"
self
.
c_hidden
*
self
.
no_heads
,
self
.
c_q
,
init
=
"final"
)
)
if
(
self
.
gating
is
not
None
):
if
self
.
gating
is
not
None
:
self
.
linear_g
=
Linear
(
self
.
c_q
,
self
.
c_hidden
*
self
.
no_heads
,
init
=
"gating"
)
self
.
linear_g
=
Linear
(
self
.
c_q
,
self
.
c_hidden
*
self
.
no_heads
,
init
=
"gating"
)
self
.
sigmoid
=
nn
.
Sigmoid
()
self
.
sigmoid
=
nn
.
Sigmoid
()
self
.
softmax
=
nn
.
Softmax
(
dim
=-
1
)
self
.
softmax
=
nn
.
Softmax
(
dim
=-
1
)
def
forward
(
self
,
def
forward
(
q_x
:
torch
.
Tensor
,
self
,
k_x
:
torch
.
Tensor
,
q_x
:
torch
.
Tensor
,
v_x
:
torch
.
Tensor
,
k_x
:
torch
.
Tensor
,
v_x
:
torch
.
Tensor
,
biases
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
,
biases
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
Args:
Args:
q_x:
q_x:
[*, Q, C_q] query data
[*, Q, C_q] query data
k_x:
k_x:
[*, K, C_k] key data
[*, K, C_k] key data
v_x:
v_x:
[*, V, C_v] value data
[*, V, C_v] value data
Returns
Returns
[*, Q, C_q] attention update
[*, Q, C_q] attention update
"""
"""
# [*, Q/K/V, H * C_hidden]
# [*, Q/K/V, H * C_hidden]
q
=
self
.
linear_q
(
q_x
)
q
=
self
.
linear_q
(
q_x
)
...
@@ -254,11 +261,11 @@ class Attention(nn.Module):
...
@@ -254,11 +261,11 @@ class Attention(nn.Module):
# [*, H, Q, K]
# [*, H, Q, K]
a
=
torch
.
matmul
(
a
=
torch
.
matmul
(
permute_final_dims
(
q
,
(
0
,
2
,
1
,
3
)),
# [*, H, Q, C_hidden]
permute_final_dims
(
q
,
(
0
,
2
,
1
,
3
)),
# [*, H, Q, C_hidden]
permute_final_dims
(
k
,
(
0
,
2
,
3
,
1
)),
# [*, H, C_hidden, K]
permute_final_dims
(
k
,
(
0
,
2
,
3
,
1
)),
# [*, H, C_hidden, K]
)
)
norm
=
1
/
math
.
sqrt
(
self
.
c_hidden
)
# [1]
norm
=
1
/
math
.
sqrt
(
self
.
c_hidden
)
# [1]
a
=
a
*
norm
a
=
a
*
norm
if
(
biases
is
not
None
)
:
if
biases
is
not
None
:
for
b
in
biases
:
for
b
in
biases
:
a
=
a
+
b
a
=
a
+
b
a
=
self
.
softmax
(
a
)
a
=
self
.
softmax
(
a
)
...
@@ -271,18 +278,18 @@ class Attention(nn.Module):
...
@@ -271,18 +278,18 @@ class Attention(nn.Module):
# [*, Q, H, C_hidden]
# [*, Q, H, C_hidden]
o
=
o
.
transpose
(
-
2
,
-
3
)
o
=
o
.
transpose
(
-
2
,
-
3
)
if
(
self
.
gating
)
:
if
self
.
gating
:
g
=
self
.
sigmoid
(
self
.
linear_g
(
q_x
))
g
=
self
.
sigmoid
(
self
.
linear_g
(
q_x
))
# [*, Q, H, C_hidden]
# [*, Q, H, C_hidden]
g
=
g
.
view
(
g
.
shape
[:
-
1
]
+
(
self
.
no_heads
,
-
1
))
g
=
g
.
view
(
g
.
shape
[:
-
1
]
+
(
self
.
no_heads
,
-
1
))
o
=
o
*
g
o
=
o
*
g
# [*, Q, H * C_hidden]
# [*, Q, H * C_hidden]
o
=
flatten_final_dims
(
o
,
2
)
o
=
flatten_final_dims
(
o
,
2
)
# [*, Q, C_q]
# [*, Q, C_q]
o
=
self
.
linear_o
(
o
)
o
=
self
.
linear_o
(
o
)
return
o
return
o
...
@@ -301,10 +308,16 @@ class GlobalAttention(nn.Module):
...
@@ -301,10 +308,16 @@ class GlobalAttention(nn.Module):
)
)
self
.
linear_k
=
Linear
(
self
.
linear_k
=
Linear
(
c_in
,
c_hidden
,
bias
=
False
,
init
=
"glorot"
,
c_in
,
c_hidden
,
bias
=
False
,
init
=
"glorot"
,
)
)
self
.
linear_v
=
Linear
(
self
.
linear_v
=
Linear
(
c_in
,
c_hidden
,
bias
=
False
,
init
=
"glorot"
,
c_in
,
c_hidden
,
bias
=
False
,
init
=
"glorot"
,
)
)
self
.
linear_g
=
Linear
(
c_in
,
c_hidden
*
no_heads
,
init
=
"gating"
)
self
.
linear_g
=
Linear
(
c_in
,
c_hidden
*
no_heads
,
init
=
"gating"
)
self
.
linear_o
=
Linear
(
c_hidden
*
no_heads
,
c_in
,
init
=
"final"
)
self
.
linear_o
=
Linear
(
c_hidden
*
no_heads
,
c_in
,
init
=
"final"
)
...
@@ -314,8 +327,9 @@ class GlobalAttention(nn.Module):
...
@@ -314,8 +327,9 @@ class GlobalAttention(nn.Module):
def
forward
(
self
,
m
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
m
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# [*, N_res, C_in]
# [*, N_res, C_in]
q
=
(
torch
.
sum
(
m
*
mask
.
unsqueeze
(
-
1
),
dim
=-
2
)
/
q
=
torch
.
sum
(
m
*
mask
.
unsqueeze
(
-
1
),
dim
=-
2
)
/
(
(
torch
.
sum
(
mask
,
dim
=-
1
)[...,
None
]
+
self
.
eps
))
torch
.
sum
(
mask
,
dim
=-
1
)[...,
None
]
+
self
.
eps
)
# [*, N_res, H * C_hidden]
# [*, N_res, H * C_hidden]
q
=
self
.
linear_q
(
q
)
q
=
self
.
linear_q
(
q
)
...
@@ -331,7 +345,7 @@ class GlobalAttention(nn.Module):
...
@@ -331,7 +345,7 @@ class GlobalAttention(nn.Module):
# [*, N_res, H, N_seq]
# [*, N_res, H, N_seq]
a
=
torch
.
matmul
(
a
=
torch
.
matmul
(
q
,
q
,
k
.
transpose
(
-
1
,
-
2
),
# [*, N_res, C_hidden, N_seq]
k
.
transpose
(
-
1
,
-
2
),
# [*, N_res, C_hidden, N_seq]
)
)
bias
=
(
self
.
inf
*
(
mask
-
1
))[...,
:,
None
,
:]
bias
=
(
self
.
inf
*
(
mask
-
1
))[...,
:,
None
,
:]
a
=
a
+
bias
a
=
a
+
bias
...
@@ -351,7 +365,7 @@ class GlobalAttention(nn.Module):
...
@@ -351,7 +365,7 @@ class GlobalAttention(nn.Module):
# [*, N_res, N_seq, H, C_hidden]
# [*, N_res, N_seq, H, C_hidden]
o
=
o
.
unsqueeze
(
-
3
)
*
g
o
=
o
.
unsqueeze
(
-
3
)
*
g
# [*, N_res, N_seq, H * C_hidden]
# [*, N_res, N_seq, H * C_hidden]
o
=
o
.
reshape
(
o
.
shape
[:
-
2
]
+
(
-
1
,))
o
=
o
.
reshape
(
o
.
shape
[:
-
2
]
+
(
-
1
,))
...
...
openfold/model/structure_module.py
View file @
07e64267
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
# Copyright 2021 DeepMind Technologies Limited
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
...
@@ -25,14 +25,14 @@ from openfold.np.residue_constants import (
...
@@ -25,14 +25,14 @@ from openfold.np.residue_constants import (
restype_atom14_mask
,
restype_atom14_mask
,
restype_atom14_rigid_group_positions
,
restype_atom14_rigid_group_positions
,
)
)
from
openfold.utils.affine_utils
import
T
,
quat_to_rot
from
openfold.utils.affine_utils
import
T
,
quat_to_rot
from
openfold.utils.feats
import
(
from
openfold.utils.feats
import
(
frames_and_literature_positions_to_atom14_pos
,
frames_and_literature_positions_to_atom14_pos
,
torsion_angles_to_frames
,
torsion_angles_to_frames
,
)
)
from
openfold.utils.tensor_utils
import
(
from
openfold.utils.tensor_utils
import
(
dict_multimap
,
dict_multimap
,
permute_final_dims
,
permute_final_dims
,
flatten_final_dims
,
flatten_final_dims
,
)
)
...
@@ -40,9 +40,9 @@ from openfold.utils.tensor_utils import (
...
@@ -40,9 +40,9 @@ from openfold.utils.tensor_utils import (
class
AngleResnetBlock
(
nn
.
Module
):
class
AngleResnetBlock
(
nn
.
Module
):
def
__init__
(
self
,
c_hidden
):
def
__init__
(
self
,
c_hidden
):
"""
"""
Args:
Args:
c_hidden:
c_hidden:
Hidden channel dimension
Hidden channel dimension
"""
"""
super
(
AngleResnetBlock
,
self
).
__init__
()
super
(
AngleResnetBlock
,
self
).
__init__
()
...
@@ -67,21 +67,22 @@ class AngleResnetBlock(nn.Module):
...
@@ -67,21 +67,22 @@ class AngleResnetBlock(nn.Module):
class
AngleResnet
(
nn
.
Module
):
class
AngleResnet
(
nn
.
Module
):
"""
"""
Implements Algorithm 20, lines 11-14
Implements Algorithm 20, lines 11-14
"""
"""
def
__init__
(
self
,
c_in
,
c_hidden
,
no_blocks
,
no_angles
,
epsilon
):
def
__init__
(
self
,
c_in
,
c_hidden
,
no_blocks
,
no_angles
,
epsilon
):
"""
"""
Args:
Args:
c_in:
c_in:
Input channel dimension
Input channel dimension
c_hidden:
c_hidden:
Hidden channel dimension
Hidden channel dimension
no_blocks:
no_blocks:
Number of resnet blocks
Number of resnet blocks
no_angles:
no_angles:
Number of torsion angles to generate
Number of torsion angles to generate
epsilon:
epsilon:
Small constant for normalization
Small constant for normalization
"""
"""
super
(
AngleResnet
,
self
).
__init__
()
super
(
AngleResnet
,
self
).
__init__
()
...
@@ -103,22 +104,21 @@ class AngleResnet(nn.Module):
...
@@ -103,22 +104,21 @@ class AngleResnet(nn.Module):
self
.
relu
=
nn
.
ReLU
()
self
.
relu
=
nn
.
ReLU
()
def
forward
(
self
,
def
forward
(
s
:
torch
.
Tensor
,
self
,
s
:
torch
.
Tensor
,
s_initial
:
torch
.
Tensor
s_initial
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
"""
Args:
Args:
s:
s:
[*, C_hidden] single embedding
[*, C_hidden] single embedding
s_initial:
s_initial:
[*, C_hidden] single embedding as of the start of the
[*, C_hidden] single embedding as of the start of the
StructureModule
StructureModule
Returns:
Returns:
[*, no_angles, 2] predicted angles
[*, no_angles, 2] predicted angles
"""
"""
# NOTE: The ReLU's applied to the inputs are absent from the supplement
# NOTE: The ReLU's applied to the inputs are absent from the supplement
# pseudocode but present in the source. For maximal compatibility with
# pseudocode but present in the source. For maximal compatibility with
# the pretrained weights, I'm going with the source.
# the pretrained weights, I'm going with the source.
# [*, C_hidden]
# [*, C_hidden]
...
@@ -153,9 +153,11 @@ class AngleResnet(nn.Module):
...
@@ -153,9 +153,11 @@ class AngleResnet(nn.Module):
class
InvariantPointAttention
(
nn
.
Module
):
class
InvariantPointAttention
(
nn
.
Module
):
"""
"""
Implements Algorithm 22.
Implements Algorithm 22.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
c_s
,
c_s
,
c_z
,
c_z
,
c_hidden
,
c_hidden
,
...
@@ -166,19 +168,19 @@ class InvariantPointAttention(nn.Module):
...
@@ -166,19 +168,19 @@ class InvariantPointAttention(nn.Module):
eps
=
1e-8
,
eps
=
1e-8
,
):
):
"""
"""
Args:
Args:
c_s:
c_s:
Single representation channel dimension
Single representation channel dimension
c_z:
c_z:
Pair representation channel dimension
Pair representation channel dimension
c_hidden:
c_hidden:
Hidden channel dimension
Hidden channel dimension
no_heads:
no_heads:
Number of attention heads
Number of attention heads
no_qk_points:
no_qk_points:
Number of query/key points to generate
Number of query/key points to generate
no_v_points:
no_v_points:
Number of value points to generate
Number of value points to generate
"""
"""
super
(
InvariantPointAttention
,
self
).
__init__
()
super
(
InvariantPointAttention
,
self
).
__init__
()
...
@@ -212,32 +214,33 @@ class InvariantPointAttention(nn.Module):
...
@@ -212,32 +214,33 @@ class InvariantPointAttention(nn.Module):
self
.
head_weights
=
nn
.
Parameter
(
torch
.
zeros
((
no_heads
)))
self
.
head_weights
=
nn
.
Parameter
(
torch
.
zeros
((
no_heads
)))
ipa_point_weights_init_
(
self
.
head_weights
)
ipa_point_weights_init_
(
self
.
head_weights
)
concat_out_dim
=
self
.
no_heads
*
(
self
.
c_z
concat_out_dim
=
self
.
no_heads
*
(
+
self
.
c_hidden
self
.
c_z
+
self
.
c_hidden
+
self
.
no_v_points
*
4
+
self
.
no_v_points
*
4
)
)
self
.
linear_out
=
Linear
(
concat_out_dim
,
self
.
c_s
,
init
=
"final"
)
self
.
linear_out
=
Linear
(
concat_out_dim
,
self
.
c_s
,
init
=
"final"
)
self
.
softmax
=
nn
.
Softmax
(
dim
=-
1
)
self
.
softmax
=
nn
.
Softmax
(
dim
=-
1
)
self
.
softplus
=
nn
.
Softplus
()
self
.
softplus
=
nn
.
Softplus
()
def
forward
(
self
,
def
forward
(
s
:
torch
.
Tensor
,
self
,
z
:
torch
.
Tensor
,
s
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
t
:
T
,
t
:
T
,
mask
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
Args:
Args:
s:
s:
[*, N_res, C_s] single representation
[*, N_res, C_s] single representation
z:
z:
[*, N_res, N_res, C_z] pair representation
[*, N_res, N_res, C_z] pair representation
t:
t:
[*, N_res] affine transformation object
[*, N_res] affine transformation object
mask:
mask:
[*, N_res] mask
[*, N_res] mask
Returns:
Returns:
[*, N_res, C_s] single representation update
[*, N_res, C_s] single representation update
"""
"""
#######################################
#######################################
# Generate scalar and point activations
# Generate scalar and point activations
...
@@ -261,12 +264,12 @@ class InvariantPointAttention(nn.Module):
...
@@ -261,12 +264,12 @@ class InvariantPointAttention(nn.Module):
# This is kind of clunky, but it's how the original does it
# This is kind of clunky, but it's how the original does it
# [*, N_res, H * P_q, 3]
# [*, N_res, H * P_q, 3]
q_pts
=
torch
.
split
(
q_pts
,
q_pts
.
shape
[
-
1
]
//
3
,
dim
=-
1
)
q_pts
=
torch
.
split
(
q_pts
,
q_pts
.
shape
[
-
1
]
//
3
,
dim
=-
1
)
q_pts
=
torch
.
stack
(
q_pts
,
dim
=-
1
)
q_pts
=
torch
.
stack
(
q_pts
,
dim
=-
1
)
q_pts
=
t
[...,
None
].
apply
(
q_pts
)
q_pts
=
t
[...,
None
].
apply
(
q_pts
)
# [*, N_res, H, P_q, 3]
# [*, N_res, H, P_q, 3]
q_pts
=
q_pts
.
view
(
q_pts
=
q_pts
.
view
(
q_pts
.
shape
[:
-
2
]
+
(
self
.
no_heads
,
self
.
no_qk_points
,
3
)
q_pts
.
shape
[:
-
2
]
+
(
self
.
no_heads
,
self
.
no_qk_points
,
3
)
)
)
# [*, N_res, H * (P_q + P_v) * 3]
# [*, N_res, H * (P_q + P_v) * 3]
...
@@ -278,15 +281,11 @@ class InvariantPointAttention(nn.Module):
...
@@ -278,15 +281,11 @@ class InvariantPointAttention(nn.Module):
kv_pts
=
t
[...,
None
].
apply
(
kv_pts
)
kv_pts
=
t
[...,
None
].
apply
(
kv_pts
)
# [*, N_res, H, (P_q + P_v), 3]
# [*, N_res, H, (P_q + P_v), 3]
kv_pts
=
kv_pts
.
view
(
kv_pts
=
kv_pts
.
view
(
kv_pts
.
shape
[:
-
2
]
+
(
self
.
no_heads
,
-
1
,
3
))
kv_pts
.
shape
[:
-
2
]
+
(
self
.
no_heads
,
-
1
,
3
)
)
# [*, N_res, H, P_q/P_v, 3]
# [*, N_res, H, P_q/P_v, 3]
k_pts
,
v_pts
=
torch
.
split
(
k_pts
,
v_pts
=
torch
.
split
(
kv_pts
,
kv_pts
,
[
self
.
no_qk_points
,
self
.
no_v_points
],
dim
=-
2
[
self
.
no_qk_points
,
self
.
no_v_points
],
dim
=-
2
)
)
##########################
##########################
...
@@ -298,12 +297,12 @@ class InvariantPointAttention(nn.Module):
...
@@ -298,12 +297,12 @@ class InvariantPointAttention(nn.Module):
# [*, H, N_res, N_res]
# [*, H, N_res, N_res]
a
=
torch
.
matmul
(
a
=
torch
.
matmul
(
permute_final_dims
(
q
,
(
1
,
0
,
2
)),
# [*, H, N_res, C_hidden]
permute_final_dims
(
q
,
(
1
,
0
,
2
)),
# [*, H, N_res, C_hidden]
permute_final_dims
(
k
,
(
1
,
2
,
0
)),
# [*, H, C_hidden, N_res]
permute_final_dims
(
k
,
(
1
,
2
,
0
)),
# [*, H, C_hidden, N_res]
)
)
a
=
a
*
math
.
sqrt
(
1.
/
(
3
*
self
.
c_hidden
))
a
=
a
*
math
.
sqrt
(
1.
0
/
(
3
*
self
.
c_hidden
))
a
=
a
+
(
math
.
sqrt
(
1.
/
3
)
*
permute_final_dims
(
b
,
(
2
,
0
,
1
)))
a
=
a
+
(
math
.
sqrt
(
1.
0
/
3
)
*
permute_final_dims
(
b
,
(
2
,
0
,
1
)))
# [*, N_res, N_res, H, P_q, 3]
# [*, N_res, N_res, H, P_q, 3]
pt_att
=
q_pts
.
unsqueeze
(
-
4
)
-
k_pts
.
unsqueeze
(
-
5
)
pt_att
=
q_pts
.
unsqueeze
(
-
4
)
-
k_pts
.
unsqueeze
(
-
5
)
pt_att
=
pt_att
**
2
pt_att
=
pt_att
**
2
...
@@ -312,12 +311,12 @@ class InvariantPointAttention(nn.Module):
...
@@ -312,12 +311,12 @@ class InvariantPointAttention(nn.Module):
pt_att
=
sum
(
torch
.
unbind
(
pt_att
,
dim
=-
1
))
pt_att
=
sum
(
torch
.
unbind
(
pt_att
,
dim
=-
1
))
head_weights
=
self
.
softplus
(
self
.
head_weights
).
view
(
head_weights
=
self
.
softplus
(
self
.
head_weights
).
view
(
*
((
1
,)
*
len
(
pt_att
.
shape
[:
-
2
])
+
(
-
1
,
1
))
*
((
1
,)
*
len
(
pt_att
.
shape
[:
-
2
])
+
(
-
1
,
1
))
)
head_weights
=
(
head_weights
*
math
.
sqrt
(
1.
/
(
3
*
(
self
.
no_qk_points
*
9.
/
2
)))
)
)
pt_att
=
pt_att
*
head_weights
head_weights
=
head_weights
*
math
.
sqrt
(
1.0
/
(
3
*
(
self
.
no_qk_points
*
9.0
/
2
))
)
pt_att
=
pt_att
*
head_weights
# [*, N_res, N_res, H]
# [*, N_res, N_res, H]
pt_att
=
torch
.
sum
(
pt_att
,
dim
=-
1
)
*
(
-
0.5
)
pt_att
=
torch
.
sum
(
pt_att
,
dim
=-
1
)
*
(
-
0.5
)
# [*, N_res, N_res]
# [*, N_res, N_res]
...
@@ -345,10 +344,10 @@ class InvariantPointAttention(nn.Module):
...
@@ -345,10 +344,10 @@ class InvariantPointAttention(nn.Module):
# [*, H, 3, N_res, P_v]
# [*, H, 3, N_res, P_v]
o_pt
=
torch
.
sum
(
o_pt
=
torch
.
sum
(
(
(
a
[...,
None
,
:,
:,
None
]
*
a
[...,
None
,
:,
:,
None
]
permute_final_dims
(
v_pts
,
(
1
,
3
,
0
,
2
))[...,
None
,
:,
:]
*
permute_final_dims
(
v_pts
,
(
1
,
3
,
0
,
2
))[...,
None
,
:,
:]
),
),
dim
=-
2
dim
=-
2
,
)
)
# [*, N_res, H, P_v, 3]
# [*, N_res, H, P_v, 3]
...
@@ -357,8 +356,7 @@ class InvariantPointAttention(nn.Module):
...
@@ -357,8 +356,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * P_v]
# [*, N_res, H * P_v]
o_pt_norm
=
flatten_final_dims
(
o_pt_norm
=
flatten_final_dims
(
torch
.
sqrt
(
torch
.
sum
(
o_pt
**
2
,
dim
=-
1
)
+
self
.
eps
),
torch
.
sqrt
(
torch
.
sum
(
o_pt
**
2
,
dim
=-
1
)
+
self
.
eps
),
2
2
)
)
# [*, N_res, H * P_v, 3]
# [*, N_res, H * P_v, 3]
...
@@ -372,26 +370,24 @@ class InvariantPointAttention(nn.Module):
...
@@ -372,26 +370,24 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, C_s]
# [*, N_res, C_s]
s
=
self
.
linear_out
(
s
=
self
.
linear_out
(
torch
.
cat
((
torch
.
cat
(
o
,
(
o
,
*
torch
.
unbind
(
o_pt
,
dim
=-
1
),
o_pt_norm
,
o_pair
),
dim
=-
1
*
torch
.
unbind
(
o_pt
,
dim
=-
1
),
o_pt_norm
,
o_pair
),
dim
=-
1
)
)
)
)
return
s
return
s
class
BackboneUpdate
(
nn
.
Module
):
class
BackboneUpdate
(
nn
.
Module
):
"""
"""
Implements Algorithm 23.
Implements Algorithm 23.
"""
"""
def
__init__
(
self
,
c_s
):
def
__init__
(
self
,
c_s
):
"""
"""
Args:
Args:
c_s:
c_s:
Single representation channel dimension
Single representation channel dimension
"""
"""
super
(
BackboneUpdate
,
self
).
__init__
()
super
(
BackboneUpdate
,
self
).
__init__
()
...
@@ -401,24 +397,24 @@ class BackboneUpdate(nn.Module):
...
@@ -401,24 +397,24 @@ class BackboneUpdate(nn.Module):
def
forward
(
self
,
s
):
def
forward
(
self
,
s
):
"""
"""
Args:
Args:
[*, N_res, C_s] single representation
[*, N_res, C_s] single representation
Returns:
Returns:
[*, N_res] affine transformation object
[*, N_res] affine transformation object
"""
"""
# [*, 6]
# [*, 6]
params
=
self
.
linear
(
s
)
params
=
self
.
linear
(
s
)
# [*, 3]
# [*, 3]
quats
,
trans
=
params
[...,:
3
],
params
[...,
3
:]
quats
,
trans
=
params
[...,
:
3
],
params
[...,
3
:]
# [*]
# [*]
#norm_denom = torch.sqrt(sum(torch.unbind(quats ** 2, dim=-1)) + 1)
#
norm_denom = torch.sqrt(sum(torch.unbind(quats ** 2, dim=-1)) + 1)
norm_denom
=
torch
.
sqrt
(
torch
.
sum
(
quats
**
2
,
dim
=-
1
)
+
1
)
norm_denom
=
torch
.
sqrt
(
torch
.
sum
(
quats
**
2
,
dim
=-
1
)
+
1
)
# [*, 3]
# [*, 3]
ones
=
(
ones
=
s
.
new_ones
((
1
,)
*
len
(
quats
.
shape
)).
expand
(
s
.
new_ones
((
1
,)
*
len
(
quats
.
shape
)).
expand
(
quats
.
shape
[:
-
1
]
+
(
1
,)
)
quats
.
shape
[:
-
1
]
+
(
1
,)
)
)
# [*, 4]
# [*, 4]
...
@@ -436,7 +432,7 @@ class StructureModuleTransitionLayer(nn.Module):
...
@@ -436,7 +432,7 @@ class StructureModuleTransitionLayer(nn.Module):
super
(
StructureModuleTransitionLayer
,
self
).
__init__
()
super
(
StructureModuleTransitionLayer
,
self
).
__init__
()
self
.
c
=
c
self
.
c
=
c
self
.
linear_1
=
Linear
(
self
.
c
,
self
.
c
,
init
=
"relu"
)
self
.
linear_1
=
Linear
(
self
.
c
,
self
.
c
,
init
=
"relu"
)
self
.
linear_2
=
Linear
(
self
.
c
,
self
.
c
,
init
=
"relu"
)
self
.
linear_2
=
Linear
(
self
.
c
,
self
.
c
,
init
=
"relu"
)
self
.
linear_3
=
Linear
(
self
.
c
,
self
.
c
,
init
=
"final"
)
self
.
linear_3
=
Linear
(
self
.
c
,
self
.
c
,
init
=
"final"
)
...
@@ -483,8 +479,9 @@ class StructureModuleTransition(nn.Module):
...
@@ -483,8 +479,9 @@ class StructureModuleTransition(nn.Module):
class
StructureModule
(
nn
.
Module
):
class
StructureModule
(
nn
.
Module
):
def
__init__
(
self
,
def
__init__
(
c_s
,
self
,
c_s
,
c_z
,
c_z
,
c_ipa
,
c_ipa
,
c_resnet
,
c_resnet
,
...
@@ -502,39 +499,39 @@ class StructureModule(nn.Module):
...
@@ -502,39 +499,39 @@ class StructureModule(nn.Module):
**
kwargs
,
**
kwargs
,
):
):
"""
"""
Args:
Args:
c_s:
c_s:
Single representation channel dimension
Single representation channel dimension
c_z:
c_z:
Pair representation channel dimension
Pair representation channel dimension
c_ipa:
c_ipa:
IPA hidden channel dimension
IPA hidden channel dimension
c_resnet:
c_resnet:
Angle resnet (Alg. 23 lines 11-14) hidden channel dimension
Angle resnet (Alg. 23 lines 11-14) hidden channel dimension
no_heads_ipa:
no_heads_ipa:
Number of IPA heads
Number of IPA heads
no_qk_points:
no_qk_points:
Number of query/key points to generate during IPA
Number of query/key points to generate during IPA
no_v_points:
no_v_points:
Number of value points to generate during IPA
Number of value points to generate during IPA
dropout_rate:
dropout_rate:
Dropout rate used throughout the layer
Dropout rate used throughout the layer
no_blocks:
no_blocks:
Number of structure module blocks
Number of structure module blocks
no_transition_layers:
no_transition_layers:
Number of layers in the single representation transition
Number of layers in the single representation transition
(Alg. 23 lines 8-9)
(Alg. 23 lines 8-9)
no_resnet_blocks:
no_resnet_blocks:
Number of blocks in the angle resnet
Number of blocks in the angle resnet
no_angles:
no_angles:
Number of angles to generate in the angle resnet
Number of angles to generate in the angle resnet
trans_scale_factor:
trans_scale_factor:
Scale of single representation transition hidden dimension
Scale of single representation transition hidden dimension
epsilon:
epsilon:
Small number used in angle resnet normalization
Small number used in angle resnet normalization
inf:
inf:
Large number used for attention masking
Large number used for attention masking
"""
"""
super
(
StructureModule
,
self
).
__init__
()
super
(
StructureModule
,
self
).
__init__
()
self
.
c_s
=
c_s
self
.
c_s
=
c_s
...
@@ -587,33 +584,34 @@ class StructureModule(nn.Module):
...
@@ -587,33 +584,34 @@ class StructureModule(nn.Module):
self
.
bb_update
=
BackboneUpdate
(
self
.
c_s
)
self
.
bb_update
=
BackboneUpdate
(
self
.
c_s
)
self
.
angle_resnet
=
AngleResnet
(
self
.
angle_resnet
=
AngleResnet
(
self
.
c_s
,
self
.
c_s
,
self
.
c_resnet
,
self
.
c_resnet
,
self
.
no_resnet_blocks
,
self
.
no_resnet_blocks
,
self
.
no_angles
,
self
.
no_angles
,
self
.
epsilon
,
self
.
epsilon
,
)
)
def
forward
(
self
,
def
forward
(
self
,
s
,
s
,
z
,
z
,
f
,
f
,
mask
=
None
,
mask
=
None
,
):
):
"""
"""
Args:
Args:
s:
s:
[*, N_res, C_s] single representation
[*, N_res, C_s] single representation
z:
z:
[*, N_res, N_res, C_z] pair representation
[*, N_res, N_res, C_z] pair representation
f:
f:
[*, N_res] amino acid indices
[*, N_res] amino acid indices
mask:
mask:
Optional [*, N_res] sequence mask
Optional [*, N_res] sequence mask
Returns:
Returns:
A dictionary of outputs
A dictionary of outputs
"""
"""
if
(
mask
is
None
)
:
if
mask
is
None
:
# [*, N]
# [*, N]
mask
=
s
.
new_ones
(
s
.
shape
[:
-
1
])
mask
=
s
.
new_ones
(
s
.
shape
[:
-
1
])
...
@@ -644,7 +642,9 @@ class StructureModule(nn.Module):
...
@@ -644,7 +642,9 @@ class StructureModule(nn.Module):
unnormalized_a
,
a
=
self
.
angle_resnet
(
s
,
s_initial
)
unnormalized_a
,
a
=
self
.
angle_resnet
(
s
,
s_initial
)
all_frames_to_global
=
self
.
torsion_angles_to_frames
(
all_frames_to_global
=
self
.
torsion_angles_to_frames
(
t
.
scale_translation
(
self
.
trans_scale_factor
),
a
,
f
,
t
.
scale_translation
(
self
.
trans_scale_factor
),
a
,
f
,
)
)
pred_xyz
=
self
.
frames_and_literature_positions_to_atom14_pos
(
pred_xyz
=
self
.
frames_and_literature_positions_to_atom14_pos
(
...
@@ -653,8 +653,7 @@ class StructureModule(nn.Module):
...
@@ -653,8 +653,7 @@ class StructureModule(nn.Module):
)
)
preds
=
{
preds
=
{
"frames"
:
"frames"
:
t
.
scale_translation
(
self
.
trans_scale_factor
).
to_4x4
(),
t
.
scale_translation
(
self
.
trans_scale_factor
).
to_4x4
(),
"sidechain_frames"
:
all_frames_to_global
.
to_4x4
(),
"sidechain_frames"
:
all_frames_to_global
.
to_4x4
(),
"unnormalized_angles"
:
unnormalized_a
,
"unnormalized_angles"
:
unnormalized_a
,
"angles"
:
a
,
"angles"
:
a
,
...
@@ -663,7 +662,7 @@ class StructureModule(nn.Module):
...
@@ -663,7 +662,7 @@ class StructureModule(nn.Module):
outputs
.
append
(
preds
)
outputs
.
append
(
preds
)
if
(
i
<
(
self
.
no_blocks
-
1
)
)
:
if
i
<
(
self
.
no_blocks
-
1
):
t
=
t
.
stop_rot_gradient
()
t
=
t
.
stop_rot_gradient
()
outputs
=
dict_multimap
(
torch
.
stack
,
outputs
)
outputs
=
dict_multimap
(
torch
.
stack
,
outputs
)
...
@@ -672,28 +671,28 @@ class StructureModule(nn.Module):
...
@@ -672,28 +671,28 @@ class StructureModule(nn.Module):
return
outputs
return
outputs
def
_init_residue_constants
(
self
,
float_dtype
,
device
):
def
_init_residue_constants
(
self
,
float_dtype
,
device
):
if
(
self
.
default_frames
is
None
)
:
if
self
.
default_frames
is
None
:
self
.
default_frames
=
torch
.
tensor
(
self
.
default_frames
=
torch
.
tensor
(
restype_rigid_group_default_frame
,
restype_rigid_group_default_frame
,
dtype
=
float_dtype
,
dtype
=
float_dtype
,
device
=
device
,
device
=
device
,
)
)
if
(
self
.
group_idx
is
None
)
:
if
self
.
group_idx
is
None
:
self
.
group_idx
=
torch
.
tensor
(
self
.
group_idx
=
torch
.
tensor
(
restype_atom14_to_rigid_group
,
restype_atom14_to_rigid_group
,
device
=
device
,
device
=
device
,
)
)
if
(
self
.
atom_mask
is
None
)
:
if
self
.
atom_mask
is
None
:
self
.
atom_mask
=
torch
.
tensor
(
self
.
atom_mask
=
torch
.
tensor
(
restype_atom14_mask
,
restype_atom14_mask
,
dtype
=
float_dtype
,
dtype
=
float_dtype
,
device
=
device
,
device
=
device
,
)
)
if
(
self
.
lit_positions
is
None
)
:
if
self
.
lit_positions
is
None
:
self
.
lit_positions
=
torch
.
tensor
(
self
.
lit_positions
=
torch
.
tensor
(
restype_atom14_rigid_group_positions
,
restype_atom14_rigid_group_positions
,
dtype
=
float_dtype
,
dtype
=
float_dtype
,
device
=
device
,
device
=
device
,
)
)
def
torsion_angles_to_frames
(
self
,
t
,
alpha
,
f
):
def
torsion_angles_to_frames
(
self
,
t
,
alpha
,
f
):
...
@@ -702,17 +701,16 @@ class StructureModule(nn.Module):
...
@@ -702,17 +701,16 @@ class StructureModule(nn.Module):
# Separated purely to make testing less annoying
# Separated purely to make testing less annoying
return
torsion_angles_to_frames
(
t
,
alpha
,
f
,
self
.
default_frames
)
return
torsion_angles_to_frames
(
t
,
alpha
,
f
,
self
.
default_frames
)
def
frames_and_literature_positions_to_atom14_pos
(
self
,
def
frames_and_literature_positions_to_atom14_pos
(
t
,
# [*, N, 8]
self
,
t
,
f
# [*, N, 8] # [*, N]
f
# [*, N]
):
):
# Lazily initialize the residue constants on the correct device
# Lazily initialize the residue constants on the correct device
self
.
_init_residue_constants
(
t
.
rots
.
dtype
,
t
.
rots
.
device
)
self
.
_init_residue_constants
(
t
.
rots
.
dtype
,
t
.
rots
.
device
)
return
frames_and_literature_positions_to_atom14_pos
(
return
frames_and_literature_positions_to_atom14_pos
(
t
,
t
,
f
,
f
,
self
.
default_frames
,
self
.
default_frames
,
self
.
group_idx
,
self
.
group_idx
,
self
.
atom_mask
,
self
.
atom_mask
,
self
.
lit_positions
,
self
.
lit_positions
,
)
)
openfold/model/template.py
View file @
07e64267
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
# Copyright 2021 DeepMind Technologies Limited
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
...
@@ -18,9 +18,9 @@ import math
...
@@ -18,9 +18,9 @@ import math
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
openfold.model.primitives
import
Linear
,
Attention
from
openfold.model.primitives
import
Linear
,
Attention
from
openfold.utils.deepspeed
import
checkpoint_blocks
from
openfold.utils.deepspeed
import
checkpoint_blocks
from
openfold.model.dropout
import
(
from
openfold.model.dropout
import
(
DropoutRowwise
,
DropoutRowwise
,
DropoutColumnwise
,
DropoutColumnwise
,
)
)
...
@@ -35,35 +35,28 @@ from openfold.model.triangular_multiplicative_update import (
...
@@ -35,35 +35,28 @@ from openfold.model.triangular_multiplicative_update import (
)
)
from
openfold.utils.tensor_utils
import
(
from
openfold.utils.tensor_utils
import
(
chunk_layer
,
chunk_layer
,
permute_final_dims
,
permute_final_dims
,
flatten_final_dims
,
flatten_final_dims
,
)
)
class
TemplatePointwiseAttention
(
nn
.
Module
):
class
TemplatePointwiseAttention
(
nn
.
Module
):
"""
"""
Implements Algorithm 17.
Implements Algorithm 17.
"""
"""
def
__init__
(
self
,
c_t
,
def
__init__
(
self
,
c_t
,
c_z
,
c_hidden
,
no_heads
,
chunk_size
,
inf
,
**
kwargs
):
c_z
,
c_hidden
,
no_heads
,
chunk_size
,
inf
,
**
kwargs
):
"""
"""
Args:
Args:
c_t:
c_t:
Template embedding channel dimension
Template embedding channel dimension
c_z:
c_z:
Pair embedding channel dimension
Pair embedding channel dimension
c_hidden:
c_hidden:
Hidden channel dimension
Hidden channel dimension
"""
"""
super
(
TemplatePointwiseAttention
,
self
).
__init__
()
super
(
TemplatePointwiseAttention
,
self
).
__init__
()
self
.
c_t
=
c_t
self
.
c_t
=
c_t
self
.
c_z
=
c_z
self
.
c_z
=
c_z
self
.
c_hidden
=
c_hidden
self
.
c_hidden
=
c_hidden
...
@@ -72,30 +65,33 @@ class TemplatePointwiseAttention(nn.Module):
...
@@ -72,30 +65,33 @@ class TemplatePointwiseAttention(nn.Module):
self
.
inf
=
inf
self
.
inf
=
inf
self
.
mha
=
Attention
(
self
.
mha
=
Attention
(
self
.
c_z
,
self
.
c_t
,
self
.
c_t
,
self
.
c_z
,
self
.
c_hidden
,
self
.
no_heads
,
self
.
c_t
,
self
.
c_t
,
self
.
c_hidden
,
self
.
no_heads
,
gating
=
False
,
gating
=
False
,
)
)
def
forward
(
self
,
t
,
z
,
template_mask
=
None
):
def
forward
(
self
,
t
,
z
,
template_mask
=
None
):
"""
"""
Args:
Args:
t:
t:
[*, N_templ, N_res, N_res, C_t] template embedding
[*, N_templ, N_res, N_res, C_t] template embedding
z:
z:
[*, N_res, N_res, C_t] pair embedding
[*, N_res, N_res, C_t] pair embedding
template_mask:
template_mask:
[*, N_templ] template mask
[*, N_templ] template mask
Returns:
Returns:
[*, N_res, N_res, C_z] pair embedding update
[*, N_res, N_res, C_z] pair embedding update
"""
"""
if
(
template_mask
is
None
)
:
if
template_mask
is
None
:
# NOTE: This is not the "template_mask" from the supplement, but a
# NOTE: This is not the "template_mask" from the supplement, but a
# [*, N_templ] mask from the code. I'm pretty sure it's always just
# [*, N_templ] mask from the code. I'm pretty sure it's always just
# 1, but not sure enough to remove it. It's nice to have, I guess.
# 1, but not sure enough to remove it. It's nice to have, I guess.
template_mask
=
t
.
new_ones
(
t
.
shape
[:
-
3
])
template_mask
=
t
.
new_ones
(
t
.
shape
[:
-
3
])
bias
=
(
self
.
inf
*
(
template_mask
[...,
None
,
None
,
None
,
None
,
:]
-
1
)
)
bias
=
self
.
inf
*
(
template_mask
[...,
None
,
None
,
None
,
None
,
:]
-
1
)
# [*, N_res, N_res, 1, C_z]
# [*, N_res, N_res, 1, C_z]
z
=
z
.
unsqueeze
(
-
2
)
z
=
z
.
unsqueeze
(
-
2
)
...
@@ -110,36 +106,37 @@ class TemplatePointwiseAttention(nn.Module):
...
@@ -110,36 +106,37 @@ class TemplatePointwiseAttention(nn.Module):
"v_x"
:
t
,
"v_x"
:
t
,
"biases"
:
[
bias
],
"biases"
:
[
bias
],
}
}
if
(
self
.
chunk_size
is
not
None
)
:
if
self
.
chunk_size
is
not
None
:
z
=
chunk_layer
(
z
=
chunk_layer
(
self
.
mha
,
self
.
mha
,
mha_inputs
,
mha_inputs
,
chunk_size
=
self
.
chunk_size
,
chunk_size
=
self
.
chunk_size
,
no_batch_dims
=
len
(
z
.
shape
[:
-
2
])
no_batch_dims
=
len
(
z
.
shape
[:
-
2
])
,
)
)
else
:
else
:
z
=
self
.
mha
(
**
mha_inputs
)
z
=
self
.
mha
(
**
mha_inputs
)
# [*, N_res, N_res, C_z]
# [*, N_res, N_res, C_z]
z
=
z
.
squeeze
(
-
2
)
z
=
z
.
squeeze
(
-
2
)
return
z
return
z
class
TemplatePairStackBlock
(
nn
.
Module
):
class
TemplatePairStackBlock
(
nn
.
Module
):
def
__init__
(
self
,
def
__init__
(
c_t
,
self
,
c_t
,
c_hidden_tri_att
,
c_hidden_tri_att
,
c_hidden_tri_mul
,
c_hidden_tri_mul
,
no_heads
,
no_heads
,
pair_transition_n
,
pair_transition_n
,
dropout_rate
,
dropout_rate
,
chunk_size
,
chunk_size
,
inf
,
inf
,
**
kwargs
,
**
kwargs
,
):
):
super
(
TemplatePairStackBlock
,
self
).
__init__
()
super
(
TemplatePairStackBlock
,
self
).
__init__
()
self
.
c_t
=
c_t
self
.
c_t
=
c_t
self
.
c_hidden_tri_att
=
c_hidden_tri_att
self
.
c_hidden_tri_att
=
c_hidden_tri_att
self
.
c_hidden_tri_mul
=
c_hidden_tri_mul
self
.
c_hidden_tri_mul
=
c_hidden_tri_mul
...
@@ -151,11 +148,11 @@ class TemplatePairStackBlock(nn.Module):
...
@@ -151,11 +148,11 @@ class TemplatePairStackBlock(nn.Module):
self
.
dropout_row
=
DropoutRowwise
(
self
.
dropout_rate
)
self
.
dropout_row
=
DropoutRowwise
(
self
.
dropout_rate
)
self
.
dropout_col
=
DropoutColumnwise
(
self
.
dropout_rate
)
self
.
dropout_col
=
DropoutColumnwise
(
self
.
dropout_rate
)
self
.
tri_att_start
=
TriangleAttentionStartingNode
(
self
.
tri_att_start
=
TriangleAttentionStartingNode
(
self
.
c_t
,
self
.
c_t
,
self
.
c_hidden_tri_att
,
self
.
c_hidden_tri_att
,
self
.
no_heads
,
self
.
no_heads
,
chunk_size
=
chunk_size
,
chunk_size
=
chunk_size
,
inf
=
inf
,
inf
=
inf
,
)
)
...
@@ -188,21 +185,23 @@ class TemplatePairStackBlock(nn.Module):
...
@@ -188,21 +185,23 @@ class TemplatePairStackBlock(nn.Module):
z
=
z
+
self
.
dropout_row
(
self
.
tri_mul_out
(
z
,
mask
=
mask
))
z
=
z
+
self
.
dropout_row
(
self
.
tri_mul_out
(
z
,
mask
=
mask
))
z
=
z
+
self
.
dropout_row
(
self
.
tri_mul_in
(
z
,
mask
=
mask
))
z
=
z
+
self
.
dropout_row
(
self
.
tri_mul_in
(
z
,
mask
=
mask
))
z
=
z
+
self
.
pair_transition
(
z
,
mask
=
mask
if
_mask_trans
else
None
)
z
=
z
+
self
.
pair_transition
(
z
,
mask
=
mask
if
_mask_trans
else
None
)
return
z
return
z
class
TemplatePairStack
(
nn
.
Module
):
class
TemplatePairStack
(
nn
.
Module
):
"""
"""
Implements Algorithm 16.
Implements Algorithm 16.
"""
"""
def
__init__
(
self
,
c_t
,
def
__init__
(
self
,
c_t
,
c_hidden_tri_att
,
c_hidden_tri_att
,
c_hidden_tri_mul
,
c_hidden_tri_mul
,
no_blocks
,
no_blocks
,
no_heads
,
no_heads
,
pair_transition_n
,
pair_transition_n
,
dropout_rate
,
dropout_rate
,
blocks_per_ckpt
,
blocks_per_ckpt
,
chunk_size
,
chunk_size
,
...
@@ -210,26 +209,26 @@ class TemplatePairStack(nn.Module):
...
@@ -210,26 +209,26 @@ class TemplatePairStack(nn.Module):
**
kwargs
,
**
kwargs
,
):
):
"""
"""
Args:
Args:
c_t:
c_t:
Template embedding channel dimension
Template embedding channel dimension
c_hidden_tri_att:
c_hidden_tri_att:
Per-head hidden dimension for triangular attention
Per-head hidden dimension for triangular attention
c_hidden_tri_att:
c_hidden_tri_att:
Hidden dimension for triangular multiplication
Hidden dimension for triangular multiplication
no_blocks:
no_blocks:
Number of blocks in the stack
Number of blocks in the stack
pair_transition_n:
pair_transition_n:
Scale of pair transition (Alg. 15) hidden dimension
Scale of pair transition (Alg. 15) hidden dimension
dropout_rate:
dropout_rate:
Dropout rate used throughout the stack
Dropout rate used throughout the stack
blocks_per_ckpt:
blocks_per_ckpt:
Number of blocks per activation checkpoint. None disables
Number of blocks per activation checkpoint. None disables
activation checkpointing
activation checkpointing
chunk_size:
chunk_size:
Size of subbatches. A higher value increases throughput at
Size of subbatches. A higher value increases throughput at
the cost of memory
the cost of memory
"""
"""
super
(
TemplatePairStack
,
self
).
__init__
()
super
(
TemplatePairStack
,
self
).
__init__
()
self
.
blocks_per_ckpt
=
blocks_per_ckpt
self
.
blocks_per_ckpt
=
blocks_per_ckpt
...
@@ -250,28 +249,30 @@ class TemplatePairStack(nn.Module):
...
@@ -250,28 +249,30 @@ class TemplatePairStack(nn.Module):
self
.
layer_norm
=
nn
.
LayerNorm
(
c_t
)
self
.
layer_norm
=
nn
.
LayerNorm
(
c_t
)
def
forward
(
self
,
def
forward
(
self
,
t
:
torch
.
tensor
,
t
:
torch
.
tensor
,
mask
:
torch
.
tensor
,
mask
:
torch
.
tensor
,
_mask_trans
:
bool
=
True
,
_mask_trans
:
bool
=
True
,
):
):
"""
"""
Args:
Args:
t:
t:
[*, N_res, N_res, C_t] template embedding
[*, N_res, N_res, C_t] template embedding
mask:
mask:
[*, N_res, N_res] mask
[*, N_res, N_res] mask
Returns:
Returns:
[*, N_res, N_res, C_t] template embedding update
[*, N_res, N_res, C_t] template embedding update
"""
"""
t
,
=
checkpoint_blocks
(
(
t
,
)
=
checkpoint_blocks
(
blocks
=
[
blocks
=
[
partial
(
partial
(
b
,
b
,
mask
=
mask
,
mask
=
mask
,
_mask_trans
=
_mask_trans
,
_mask_trans
=
_mask_trans
,
)
for
b
in
self
.
blocks
)
],
for
b
in
self
.
blocks
],
args
=
(
t
,),
args
=
(
t
,),
blocks_per_ckpt
=
self
.
blocks_per_ckpt
if
self
.
training
else
None
,
blocks_per_ckpt
=
self
.
blocks_per_ckpt
if
self
.
training
else
None
,
)
)
...
...
openfold/model/triangular_attention.py
View file @
07e64267
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
# Copyright 2021 DeepMind Technologies Limited
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
...
@@ -20,29 +20,24 @@ import torch.nn as nn
...
@@ -20,29 +20,24 @@ import torch.nn as nn
from
openfold.model.primitives
import
Linear
,
Attention
from
openfold.model.primitives
import
Linear
,
Attention
from
openfold.utils.tensor_utils
import
(
from
openfold.utils.tensor_utils
import
(
chunk_layer
,
chunk_layer
,
permute_final_dims
,
permute_final_dims
,
flatten_final_dims
,
flatten_final_dims
,
)
)
class
TriangleAttention
(
nn
.
Module
):
class
TriangleAttention
(
nn
.
Module
):
def
__init__
(
self
,
def
__init__
(
c_in
,
self
,
c_in
,
c_hidden
,
no_heads
,
starting
,
chunk_size
=
4
,
inf
=
1e9
c_hidden
,
no_heads
,
starting
,
chunk_size
=
4
,
inf
=
1e9
):
):
"""
"""
Args:
Args:
c_in:
c_in:
Input channel dimension
Input channel dimension
c_hidden:
c_hidden:
Overall hidden channel dimension (not per-head)
Overall hidden channel dimension (not per-head)
no_heads:
no_heads:
Number of attention heads
Number of attention heads
"""
"""
super
(
TriangleAttention
,
self
).
__init__
()
super
(
TriangleAttention
,
self
).
__init__
()
...
@@ -54,40 +49,38 @@ class TriangleAttention(nn.Module):
...
@@ -54,40 +49,38 @@ class TriangleAttention(nn.Module):
self
.
inf
=
inf
self
.
inf
=
inf
self
.
layer_norm
=
nn
.
LayerNorm
(
self
.
c_in
)
self
.
layer_norm
=
nn
.
LayerNorm
(
self
.
c_in
)
self
.
linear
=
Linear
(
c_in
,
self
.
no_heads
,
bias
=
False
,
init
=
"normal"
)
self
.
linear
=
Linear
(
c_in
,
self
.
no_heads
,
bias
=
False
,
init
=
"normal"
)
self
.
mha
=
Attention
(
self
.
mha
=
Attention
(
self
.
c_in
,
self
.
c_in
,
self
.
c_in
,
self
.
c_in
,
self
.
c_in
,
self
.
c_in
,
self
.
c_hidden
,
self
.
no_heads
self
.
c_hidden
,
self
.
no_heads
)
)
def
forward
(
self
,
x
,
mask
=
None
):
def
forward
(
self
,
x
,
mask
=
None
):
"""
"""
Args:
Args:
x:
x:
[*, I, J, C_in] input tensor (e.g. the pair representation)
[*, I, J, C_in] input tensor (e.g. the pair representation)
Returns:
Returns:
[*, I, J, C_in] output tensor
[*, I, J, C_in] output tensor
"""
"""
if
(
mask
is
None
)
:
if
mask
is
None
:
# [*, I, J]
# [*, I, J]
mask
=
x
.
new_ones
(
mask
=
x
.
new_ones
(
x
.
shape
[:
-
1
],
x
.
shape
[:
-
1
],
)
)
# Shape annotations assume self.starting. Else, I and J are flipped
# Shape annotations assume self.starting. Else, I and J are flipped
if
(
not
self
.
starting
)
:
if
not
self
.
starting
:
x
=
x
.
transpose
(
-
2
,
-
3
)
x
=
x
.
transpose
(
-
2
,
-
3
)
mask
=
mask
.
transpose
(
-
1
,
-
2
)
mask
=
mask
.
transpose
(
-
1
,
-
2
)
# [*, I, J, C_in]
# [*, I, J, C_in]
x
=
self
.
layer_norm
(
x
)
x
=
self
.
layer_norm
(
x
)
# [*, I, 1, 1, J]
# [*, I, 1, 1, J]
mask_bias
=
(
self
.
inf
*
(
mask
-
1
))[...,
:,
None
,
None
,
:]
mask_bias
=
(
self
.
inf
*
(
mask
-
1
))[...,
:,
None
,
None
,
:]
# [*, H, I, J]
# [*, H, I, J]
triangle_bias
=
permute_final_dims
(
self
.
linear
(
x
),
(
2
,
0
,
1
))
triangle_bias
=
permute_final_dims
(
self
.
linear
(
x
),
(
2
,
0
,
1
))
...
@@ -100,17 +93,17 @@ class TriangleAttention(nn.Module):
...
@@ -100,17 +93,17 @@ class TriangleAttention(nn.Module):
"v_x"
:
x
,
"v_x"
:
x
,
"biases"
:
[
mask_bias
,
triangle_bias
],
"biases"
:
[
mask_bias
,
triangle_bias
],
}
}
if
(
self
.
chunk_size
is
not
None
)
:
if
self
.
chunk_size
is
not
None
:
x
=
chunk_layer
(
x
=
chunk_layer
(
self
.
mha
,
self
.
mha
,
mha_inputs
,
mha_inputs
,
chunk_size
=
self
.
chunk_size
,
chunk_size
=
self
.
chunk_size
,
no_batch_dims
=
len
(
x
.
shape
[:
-
2
])
no_batch_dims
=
len
(
x
.
shape
[:
-
2
])
,
)
)
else
:
else
:
x
=
self
.
mha
(
**
mha_inputs
)
x
=
self
.
mha
(
**
mha_inputs
)
if
(
not
self
.
starting
)
:
if
not
self
.
starting
:
x
=
x
.
transpose
(
-
2
,
-
3
)
x
=
x
.
transpose
(
-
2
,
-
3
)
return
x
return
x
...
@@ -118,13 +111,15 @@ class TriangleAttention(nn.Module):
...
@@ -118,13 +111,15 @@ class TriangleAttention(nn.Module):
class
TriangleAttentionStartingNode
(
TriangleAttention
):
class
TriangleAttentionStartingNode
(
TriangleAttention
):
"""
"""
Implements Algorithm 13.
Implements Algorithm 13.
"""
"""
__init__
=
partialmethod
(
TriangleAttention
.
__init__
,
starting
=
True
)
__init__
=
partialmethod
(
TriangleAttention
.
__init__
,
starting
=
True
)
class
TriangleAttentionEndingNode
(
TriangleAttention
):
class
TriangleAttentionEndingNode
(
TriangleAttention
):
"""
"""
Implements Algorithm 14.
Implements Algorithm 14.
"""
"""
__init__
=
partialmethod
(
TriangleAttention
.
__init__
,
starting
=
False
)
__init__
=
partialmethod
(
TriangleAttention
.
__init__
,
starting
=
False
)
openfold/model/triangular_multiplicative_update.py
View file @
07e64267
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
# Copyright 2021 DeepMind Technologies Limited
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
...
@@ -23,16 +23,17 @@ from openfold.utils.tensor_utils import permute_final_dims
...
@@ -23,16 +23,17 @@ from openfold.utils.tensor_utils import permute_final_dims
class
TriangleMultiplicativeUpdate
(
nn
.
Module
):
class
TriangleMultiplicativeUpdate
(
nn
.
Module
):
"""
"""
Implements Algorithms 11 and 12.
Implements Algorithms 11 and 12.
"""
"""
def
__init__
(
self
,
c_z
,
c_hidden
,
_outgoing
=
True
):
def
__init__
(
self
,
c_z
,
c_hidden
,
_outgoing
=
True
):
"""
"""
Args:
Args:
c_z:
c_z:
Input channel dimension
Input channel dimension
c:
c:
Hidden channel dimension
Hidden channel dimension
"""
"""
super
(
TriangleMultiplicativeUpdate
,
self
).
__init__
()
super
(
TriangleMultiplicativeUpdate
,
self
).
__init__
()
self
.
c_z
=
c_z
self
.
c_z
=
c_z
self
.
c_hidden
=
c_hidden
self
.
c_hidden
=
c_hidden
...
@@ -53,22 +54,24 @@ class TriangleMultiplicativeUpdate(nn.Module):
...
@@ -53,22 +54,24 @@ class TriangleMultiplicativeUpdate(nn.Module):
cp
=
self
.
_outgoing_matmul
if
self
.
_outgoing
else
self
.
_incoming_matmul
cp
=
self
.
_outgoing_matmul
if
self
.
_outgoing
else
self
.
_incoming_matmul
self
.
combine_projections
=
cp
self
.
combine_projections
=
cp
def
_outgoing_matmul
(
self
,
def
_outgoing_matmul
(
a
:
torch
.
Tensor
,
# [*, N_i, N_k, C]
self
,
b
:
torch
.
Tensor
,
# [*, N_j, N_k, C]
a
:
torch
.
Tensor
,
# [*, N_i, N_k, C]
b
:
torch
.
Tensor
,
# [*, N_j, N_k, C]
):
):
# [*, C, N_i, N_j]
# [*, C, N_i, N_j]
p
=
torch
.
matmul
(
p
=
torch
.
matmul
(
permute_final_dims
(
a
,
(
2
,
0
,
1
)),
permute_final_dims
(
a
,
(
2
,
0
,
1
)),
permute_final_dims
(
b
,
(
2
,
1
,
0
)),
permute_final_dims
(
b
,
(
2
,
1
,
0
)),
)
)
# [*, N_i, N_j, C]
# [*, N_i, N_j, C]
return
permute_final_dims
(
p
,
(
1
,
2
,
0
))
return
permute_final_dims
(
p
,
(
1
,
2
,
0
))
def
_incoming_matmul
(
self
,
def
_incoming_matmul
(
a
:
torch
.
Tensor
,
# [*, N_k, N_i, C]
self
,
b
:
torch
.
Tensor
,
# [*, N_k, N_j, C]
a
:
torch
.
Tensor
,
# [*, N_k, N_i, C]
b
:
torch
.
Tensor
,
# [*, N_k, N_j, C]
):
):
# [*, C, N_i, N_j]
# [*, C, N_i, N_j]
...
@@ -76,21 +79,21 @@ class TriangleMultiplicativeUpdate(nn.Module):
...
@@ -76,21 +79,21 @@ class TriangleMultiplicativeUpdate(nn.Module):
permute_final_dims
(
a
,
(
2
,
1
,
0
)),
permute_final_dims
(
a
,
(
2
,
1
,
0
)),
permute_final_dims
(
b
,
(
2
,
0
,
1
)),
permute_final_dims
(
b
,
(
2
,
0
,
1
)),
)
)
# [*, N_i, N_j, C]
# [*, N_i, N_j, C]
return
permute_final_dims
(
p
,
(
1
,
2
,
0
))
return
permute_final_dims
(
p
,
(
1
,
2
,
0
))
def
forward
(
self
,
z
,
mask
=
None
):
def
forward
(
self
,
z
,
mask
=
None
):
"""
"""
Args:
Args:
x:
x:
[*, N_res, N_res, C_z] input tensor
[*, N_res, N_res, C_z] input tensor
mask:
mask:
[*, N_res, N_res] input mask
[*, N_res, N_res] input mask
Returns:
Returns:
[*, N_res, N_res, C_z] output tensor
[*, N_res, N_res, C_z] output tensor
"""
"""
if
(
mask
is
None
)
:
if
mask
is
None
:
mask
=
z
.
new_ones
(
z
.
shape
[:
-
1
])
mask
=
z
.
new_ones
(
z
.
shape
[:
-
1
])
mask
=
mask
.
unsqueeze
(
-
1
)
mask
=
mask
.
unsqueeze
(
-
1
)
...
@@ -111,17 +114,21 @@ class TriangleMultiplicativeUpdate(nn.Module):
...
@@ -111,17 +114,21 @@ class TriangleMultiplicativeUpdate(nn.Module):
class
TriangleMultiplicationOutgoing
(
TriangleMultiplicativeUpdate
):
class
TriangleMultiplicationOutgoing
(
TriangleMultiplicativeUpdate
):
"""
"""
Implements Algorithm 11.
Implements Algorithm 11.
"""
"""
__init__
=
partialmethod
(
__init__
=
partialmethod
(
TriangleMultiplicativeUpdate
.
__init__
,
_outgoing
=
True
,
TriangleMultiplicativeUpdate
.
__init__
,
_outgoing
=
True
,
)
)
class
TriangleMultiplicationIncoming
(
TriangleMultiplicativeUpdate
):
class
TriangleMultiplicationIncoming
(
TriangleMultiplicativeUpdate
):
"""
"""
Implements Algorithm 12.
Implements Algorithm 12.
"""
"""
__init__
=
partialmethod
(
__init__
=
partialmethod
(
TriangleMultiplicativeUpdate
.
__init__
,
_outgoing
=
False
,
TriangleMultiplicativeUpdate
.
__init__
,
_outgoing
=
False
,
)
)
openfold/np/__init__.py
View file @
07e64267
...
@@ -3,12 +3,14 @@ import glob
...
@@ -3,12 +3,14 @@ import glob
import
importlib
as
importlib
import
importlib
as
importlib
_files
=
glob
.
glob
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"*.py"
))
_files
=
glob
.
glob
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"*.py"
))
__all__
=
[
os
.
path
.
basename
(
f
)[:
-
3
]
for
f
in
_files
if
os
.
path
.
isfile
(
f
)
and
not
f
.
endswith
(
"__init__.py"
)]
__all__
=
[
_modules
=
[(
m
,
importlib
.
import_module
(
'.'
+
m
,
__name__
))
for
m
in
__all__
]
os
.
path
.
basename
(
f
)[:
-
3
]
for
f
in
_files
if
os
.
path
.
isfile
(
f
)
and
not
f
.
endswith
(
"__init__.py"
)
]
_modules
=
[(
m
,
importlib
.
import_module
(
"."
+
m
,
__name__
))
for
m
in
__all__
]
for
_m
in
_modules
:
for
_m
in
_modules
:
globals
()[
_m
[
0
]]
=
_m
[
1
]
globals
()[
_m
[
0
]]
=
_m
[
1
]
# Avoid needlessly cluttering the global namespace
# Avoid needlessly cluttering the global namespace
del
_files
,
_m
,
_modules
del
_files
,
_m
,
_modules
openfold/np/protein.py
View file @
07e64267
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
# Copyright 2021 DeepMind Technologies Limited
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
...
@@ -27,204 +27,220 @@ ModelOutput = Mapping[str, Any] # Is a nested dict.
...
@@ -27,204 +27,220 @@ ModelOutput = Mapping[str, Any] # Is a nested dict.
@
dataclasses
.
dataclass
(
frozen
=
True
)
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
Protein
:
class
Protein
:
"""Protein structure representation."""
"""Protein structure representation."""
# Cartesian coordinates of atoms in angstroms. The atom types correspond to
# Cartesian coordinates of atoms in angstroms. The atom types correspond to
# residue_constants.atom_types, i.e. the first three are N, CA, CB.
# residue_constants.atom_types, i.e. the first three are N, CA, CB.
atom_positions
:
np
.
ndarray
# [num_res, num_atom_type, 3]
atom_positions
:
np
.
ndarray
# [num_res, num_atom_type, 3]
# Amino-acid type for each residue represented as an integer between 0 and
# Amino-acid type for each residue represented as an integer between 0 and
# 20, where 20 is 'X'.
# 20, where 20 is 'X'.
aatype
:
np
.
ndarray
# [num_res]
aatype
:
np
.
ndarray
# [num_res]
# Binary float mask to indicate presence of a particular atom. 1.0 if an atom
# Binary float mask to indicate presence of a particular atom. 1.0 if an atom
# is present and 0.0 if not. This should be used for loss masking.
# is present and 0.0 if not. This should be used for loss masking.
atom_mask
:
np
.
ndarray
# [num_res, num_atom_type]
atom_mask
:
np
.
ndarray
# [num_res, num_atom_type]
# Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
# Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
residue_index
:
np
.
ndarray
# [num_res]
residue_index
:
np
.
ndarray
# [num_res]
# B-factors, or temperature factors, of each residue (in sq. angstroms units),
# B-factors, or temperature factors, of each residue (in sq. angstroms units),
# representing the displacement of the residue from its ground truth mean
# representing the displacement of the residue from its ground truth mean
# value.
# value.
b_factors
:
np
.
ndarray
# [num_res, num_atom_type]
b_factors
:
np
.
ndarray
# [num_res, num_atom_type]
def
from_pdb_string
(
pdb_str
:
str
,
chain_id
:
Optional
[
str
]
=
None
)
->
Protein
:
def
from_pdb_string
(
pdb_str
:
str
,
chain_id
:
Optional
[
str
]
=
None
)
->
Protein
:
"""Takes a PDB string and constructs a Protein object.
"""Takes a PDB string and constructs a Protein object.
WARNING: All non-standard residue types will be converted into UNK. All
WARNING: All non-standard residue types will be converted into UNK. All
non-standard atoms will be ignored.
non-standard atoms will be ignored.
Args:
Args:
pdb_str: The contents of the pdb file
pdb_str: The contents of the pdb file
chain_id: If None, then the pdb file must contain a single chain (which
chain_id: If None, then the pdb file must contain a single chain (which
will be parsed). If chain_id is specified (e.g. A), then only that chain
will be parsed). If chain_id is specified (e.g. A), then only that chain
is parsed.
is parsed.
Returns:
Returns:
A new `Protein` parsed from the pdb contents.
A new `Protein` parsed from the pdb contents.
"""
"""
pdb_fh
=
io
.
StringIO
(
pdb_str
)
pdb_fh
=
io
.
StringIO
(
pdb_str
)
parser
=
PDBParser
(
QUIET
=
True
)
parser
=
PDBParser
(
QUIET
=
True
)
structure
=
parser
.
get_structure
(
'none'
,
pdb_fh
)
structure
=
parser
.
get_structure
(
"none"
,
pdb_fh
)
models
=
list
(
structure
.
get_models
())
models
=
list
(
structure
.
get_models
())
if
len
(
models
)
!=
1
:
if
len
(
models
)
!=
1
:
raise
ValueError
(
raise
ValueError
(
f
'Only single model PDBs are supported. Found
{
len
(
models
)
}
models.'
)
f
"Only single model PDBs are supported. Found
{
len
(
models
)
}
models."
model
=
models
[
0
]
)
model
=
models
[
0
]
if
chain_id
is
not
None
:
chain
=
model
[
chain_id
]
if
chain_id
is
not
None
:
else
:
chain
=
model
[
chain_id
]
chains
=
list
(
model
.
get_chains
())
if
len
(
chains
)
!=
1
:
raise
ValueError
(
'Only single chain PDBs are supported when chain_id not specified. '
f
'Found
{
len
(
chains
)
}
chains.'
)
else
:
else
:
chain
=
chains
[
0
]
chains
=
list
(
model
.
get_chains
())
if
len
(
chains
)
!=
1
:
atom_positions
=
[]
raise
ValueError
(
aatype
=
[]
"Only single chain PDBs are supported when chain_id not specified. "
atom_mask
=
[]
f
"Found
{
len
(
chains
)
}
chains."
residue_index
=
[]
)
b_factors
=
[]
else
:
chain
=
chains
[
0
]
for
res
in
chain
:
if
res
.
id
[
2
]
!=
' '
:
atom_positions
=
[]
raise
ValueError
(
aatype
=
[]
f
'PDB contains an insertion code at chain
{
chain
.
id
}
and residue '
atom_mask
=
[]
f
'index
{
res
.
id
[
1
]
}
. These are not supported.'
)
residue_index
=
[]
res_shortname
=
residue_constants
.
restype_3to1
.
get
(
res
.
resname
,
'X'
)
b_factors
=
[]
restype_idx
=
residue_constants
.
restype_order
.
get
(
res_shortname
,
residue_constants
.
restype_num
)
for
res
in
chain
:
pos
=
np
.
zeros
((
residue_constants
.
atom_type_num
,
3
))
if
res
.
id
[
2
]
!=
" "
:
mask
=
np
.
zeros
((
residue_constants
.
atom_type_num
,))
raise
ValueError
(
res_b_factors
=
np
.
zeros
((
residue_constants
.
atom_type_num
,))
f
"PDB contains an insertion code at chain
{
chain
.
id
}
and residue "
for
atom
in
res
:
f
"index
{
res
.
id
[
1
]
}
. These are not supported."
if
atom
.
name
not
in
residue_constants
.
atom_types
:
)
continue
res_shortname
=
residue_constants
.
restype_3to1
.
get
(
res
.
resname
,
"X"
)
pos
[
residue_constants
.
atom_order
[
atom
.
name
]]
=
atom
.
coord
restype_idx
=
residue_constants
.
restype_order
.
get
(
mask
[
residue_constants
.
atom_order
[
atom
.
name
]]
=
1.
res_shortname
,
residue_constants
.
restype_num
res_b_factors
[
residue_constants
.
atom_order
[
atom
.
name
]]
=
atom
.
bfactor
)
if
np
.
sum
(
mask
)
<
0.5
:
pos
=
np
.
zeros
((
residue_constants
.
atom_type_num
,
3
))
# If no known atom positions are reported for the residue then skip it.
mask
=
np
.
zeros
((
residue_constants
.
atom_type_num
,))
continue
res_b_factors
=
np
.
zeros
((
residue_constants
.
atom_type_num
,))
aatype
.
append
(
restype_idx
)
for
atom
in
res
:
atom_positions
.
append
(
pos
)
if
atom
.
name
not
in
residue_constants
.
atom_types
:
atom_mask
.
append
(
mask
)
continue
residue_index
.
append
(
res
.
id
[
1
])
pos
[
residue_constants
.
atom_order
[
atom
.
name
]]
=
atom
.
coord
b_factors
.
append
(
res_b_factors
)
mask
[
residue_constants
.
atom_order
[
atom
.
name
]]
=
1.0
res_b_factors
[
return
Protein
(
residue_constants
.
atom_order
[
atom
.
name
]
atom_positions
=
np
.
array
(
atom_positions
),
]
=
atom
.
bfactor
atom_mask
=
np
.
array
(
atom_mask
),
if
np
.
sum
(
mask
)
<
0.5
:
aatype
=
np
.
array
(
aatype
),
# If no known atom positions are reported for the residue then skip it.
residue_index
=
np
.
array
(
residue_index
),
continue
b_factors
=
np
.
array
(
b_factors
))
aatype
.
append
(
restype_idx
)
atom_positions
.
append
(
pos
)
atom_mask
.
append
(
mask
)
residue_index
.
append
(
res
.
id
[
1
])
b_factors
.
append
(
res_b_factors
)
return
Protein
(
atom_positions
=
np
.
array
(
atom_positions
),
atom_mask
=
np
.
array
(
atom_mask
),
aatype
=
np
.
array
(
aatype
),
residue_index
=
np
.
array
(
residue_index
),
b_factors
=
np
.
array
(
b_factors
),
)
def
to_pdb
(
prot
:
Protein
)
->
str
:
def
to_pdb
(
prot
:
Protein
)
->
str
:
"""Converts a `Protein` instance to a PDB string.
"""Converts a `Protein` instance to a PDB string.
Args:
Args:
prot: The protein to convert to PDB.
prot: The protein to convert to PDB.
Returns:
Returns:
PDB string.
PDB string.
"""
"""
restypes
=
residue_constants
.
restypes
+
[
'X'
]
restypes
=
residue_constants
.
restypes
+
[
"X"
]
res_1to3
=
lambda
r
:
residue_constants
.
restype_1to3
.
get
(
restypes
[
r
],
'UNK'
)
res_1to3
=
lambda
r
:
residue_constants
.
restype_1to3
.
get
(
restypes
[
r
],
"UNK"
)
atom_types
=
residue_constants
.
atom_types
atom_types
=
residue_constants
.
atom_types
pdb_lines
=
[]
pdb_lines
=
[]
atom_mask
=
prot
.
atom_mask
atom_mask
=
prot
.
atom_mask
aatype
=
prot
.
aatype
aatype
=
prot
.
aatype
atom_positions
=
prot
.
atom_positions
atom_positions
=
prot
.
atom_positions
residue_index
=
prot
.
residue_index
.
astype
(
np
.
int32
)
residue_index
=
prot
.
residue_index
.
astype
(
np
.
int32
)
b_factors
=
prot
.
b_factors
b_factors
=
prot
.
b_factors
if
np
.
any
(
aatype
>
residue_constants
.
restype_num
):
if
np
.
any
(
aatype
>
residue_constants
.
restype_num
):
raise
ValueError
(
'Invalid aatypes.'
)
raise
ValueError
(
"Invalid aatypes."
)
pdb_lines
.
append
(
'MODEL 1'
)
pdb_lines
.
append
(
"MODEL 1"
)
atom_index
=
1
atom_index
=
1
chain_id
=
'A'
chain_id
=
"A"
# Add all atom sites.
# Add all atom sites.
for
i
in
range
(
aatype
.
shape
[
0
]):
for
i
in
range
(
aatype
.
shape
[
0
]):
res_name_3
=
res_1to3
(
aatype
[
i
])
res_name_3
=
res_1to3
(
aatype
[
i
])
for
atom_name
,
pos
,
mask
,
b_factor
in
zip
(
for
atom_name
,
pos
,
mask
,
b_factor
in
zip
(
atom_types
,
atom_positions
[
i
],
atom_mask
[
i
],
b_factors
[
i
]):
atom_types
,
atom_positions
[
i
],
atom_mask
[
i
],
b_factors
[
i
]
if
mask
<
0.5
:
):
continue
if
mask
<
0.5
:
continue
record_type
=
'ATOM'
name
=
atom_name
if
len
(
atom_name
)
==
4
else
f
'
{
atom_name
}
'
record_type
=
"ATOM"
alt_loc
=
''
name
=
atom_name
if
len
(
atom_name
)
==
4
else
f
"
{
atom_name
}
"
insertion_code
=
''
alt_loc
=
""
occupancy
=
1.00
insertion_code
=
""
element
=
atom_name
[
0
]
# Protein supports only C, N, O, S, this works.
occupancy
=
1.00
charge
=
''
element
=
atom_name
[
# PDB is a columnar format, every space matters here!
0
atom_line
=
(
f
'
{
record_type
:
<
6
}{
atom_index
:
>
5
}
{
name
:
<
4
}{
alt_loc
:
>
1
}
'
]
# Protein supports only C, N, O, S, this works.
f
'
{
res_name_3
:
>
3
}
{
chain_id
:
>
1
}
'
charge
=
""
f
'
{
residue_index
[
i
]:
>
4
}{
insertion_code
:
>
1
}
'
# PDB is a columnar format, every space matters here!
f
'
{
pos
[
0
]:
>
8.3
f
}{
pos
[
1
]:
>
8.3
f
}{
pos
[
2
]:
>
8.3
f
}
'
atom_line
=
(
f
'
{
occupancy
:
>
6.2
f
}{
b_factor
:
>
6.2
f
}
'
f
"
{
record_type
:
<
6
}{
atom_index
:
>
5
}
{
name
:
<
4
}{
alt_loc
:
>
1
}
"
f
'
{
element
:
>
2
}{
charge
:
>
2
}
'
)
f
"
{
res_name_3
:
>
3
}
{
chain_id
:
>
1
}
"
pdb_lines
.
append
(
atom_line
)
f
"
{
residue_index
[
i
]:
>
4
}{
insertion_code
:
>
1
}
"
atom_index
+=
1
f
"
{
pos
[
0
]:
>
8.3
f
}{
pos
[
1
]:
>
8.3
f
}{
pos
[
2
]:
>
8.3
f
}
"
f
"
{
occupancy
:
>
6.2
f
}{
b_factor
:
>
6.2
f
}
"
# Close the chain.
f
"
{
element
:
>
2
}{
charge
:
>
2
}
"
chain_end
=
'TER'
)
chain_termination_line
=
(
pdb_lines
.
append
(
atom_line
)
f
'
{
chain_end
:
<
6
}{
atom_index
:
>
5
}
{
res_1to3
(
aatype
[
-
1
]):
>
3
}
'
atom_index
+=
1
f
'
{
chain_id
:
>
1
}{
residue_index
[
-
1
]:
>
4
}
'
)
pdb_lines
.
append
(
chain_termination_line
)
# Close the chain.
pdb_lines
.
append
(
'ENDMDL'
)
chain_end
=
"TER"
chain_termination_line
=
(
pdb_lines
.
append
(
'END'
)
f
"
{
chain_end
:
<
6
}{
atom_index
:
>
5
}
{
res_1to3
(
aatype
[
-
1
]):
>
3
}
"
pdb_lines
.
append
(
''
)
f
"
{
chain_id
:
>
1
}{
residue_index
[
-
1
]:
>
4
}
"
return
'
\n
'
.
join
(
pdb_lines
)
)
pdb_lines
.
append
(
chain_termination_line
)
pdb_lines
.
append
(
"ENDMDL"
)
pdb_lines
.
append
(
"END"
)
pdb_lines
.
append
(
""
)
return
"
\n
"
.
join
(
pdb_lines
)
def
ideal_atom_mask
(
prot
:
Protein
)
->
np
.
ndarray
:
def
ideal_atom_mask
(
prot
:
Protein
)
->
np
.
ndarray
:
"""Computes an ideal atom mask.
"""Computes an ideal atom mask.
`Protein.atom_mask` typically is defined according to the atoms that are
`Protein.atom_mask` typically is defined according to the atoms that are
reported in the PDB. This function computes a mask according to heavy atoms
reported in the PDB. This function computes a mask according to heavy atoms
that should be present in the given sequence of amino acids.
that should be present in the given sequence of amino acids.
Args:
Args:
prot: `Protein` whose fields are `numpy.ndarray` objects.
prot: `Protein` whose fields are `numpy.ndarray` objects.
Returns:
Returns:
An ideal atom mask.
An ideal atom mask.
"""
"""
return
residue_constants
.
STANDARD_ATOM_MASK
[
prot
.
aatype
]
return
residue_constants
.
STANDARD_ATOM_MASK
[
prot
.
aatype
]
def
from_prediction
(
features
:
FeatureDict
,
result
:
ModelOutput
,
def
from_prediction
(
b_factors
:
Optional
[
np
.
ndarray
]
=
None
)
->
Protein
:
features
:
FeatureDict
,
"""Assembles a protein from a prediction.
result
:
ModelOutput
,
b_factors
:
Optional
[
np
.
ndarray
]
=
None
,
Args:
)
->
Protein
:
features: Dictionary holding model inputs.
"""Assembles a protein from a prediction.
result: Dictionary holding model outputs.
b_factors: (Optional) B-factors to use for the protein.
Args:
features: Dictionary holding model inputs.
Returns:
result: Dictionary holding model outputs.
A protein instance.
b_factors: (Optional) B-factors to use for the protein.
"""
if
b_factors
is
None
:
Returns:
b_factors
=
np
.
zeros_like
(
result
[
'final_atom_mask'
])
A protein instance.
"""
return
Protein
(
if
b_factors
is
None
:
aatype
=
features
[
'aatype'
],
b_factors
=
np
.
zeros_like
(
result
[
"final_atom_mask"
])
atom_positions
=
result
[
'final_atom_positions'
],
atom_mask
=
result
[
'final_atom_mask'
],
return
Protein
(
residue_index
=
features
[
'residue_index'
]
+
1
,
aatype
=
features
[
"aatype"
],
b_factors
=
b_factors
atom_positions
=
result
[
"final_atom_positions"
],
)
atom_mask
=
result
[
"final_atom_mask"
],
residue_index
=
features
[
"residue_index"
]
+
1
,
b_factors
=
b_factors
,
)
openfold/np/relax/__init__.py
View file @
07e64267
...
@@ -3,13 +3,14 @@ import glob
...
@@ -3,13 +3,14 @@ import glob
import
importlib
as
importlib
import
importlib
as
importlib
_files
=
glob
.
glob
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"*.py"
))
_files
=
glob
.
glob
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"*.py"
))
__all__
=
[
os
.
path
.
basename
(
f
)[:
-
3
]
for
f
in
_files
if
os
.
path
.
isfile
(
f
)
and
not
f
.
endswith
(
"__init__.py"
)]
__all__
=
[
_modules
=
[(
m
,
importlib
.
import_module
(
'.'
+
m
,
__name__
))
for
m
in
__all__
]
os
.
path
.
basename
(
f
)[:
-
3
]
for
f
in
_files
if
os
.
path
.
isfile
(
f
)
and
not
f
.
endswith
(
"__init__.py"
)
]
_modules
=
[(
m
,
importlib
.
import_module
(
"."
+
m
,
__name__
))
for
m
in
__all__
]
for
_m
in
_modules
:
for
_m
in
_modules
:
globals
()[
_m
[
0
]]
=
_m
[
1
]
globals
()[
_m
[
0
]]
=
_m
[
1
]
# Avoid needlessly cluttering the global namespace
# Avoid needlessly cluttering the global namespace
del
_files
,
_m
,
_modules
del
_files
,
_m
,
_modules
openfold/np/relax/amber_minimize.py
View file @
07e64267
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
# Copyright 2021 DeepMind Technologies Limited
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
...
@@ -38,12 +38,12 @@ LENGTH = unit.angstroms
...
@@ -38,12 +38,12 @@ LENGTH = unit.angstroms
def
will_restrain
(
atom
:
openmm_app
.
Atom
,
rset
:
str
)
->
bool
:
def
will_restrain
(
atom
:
openmm_app
.
Atom
,
rset
:
str
)
->
bool
:
"""Returns True if the atom will be restrained by the given restraint set."""
"""Returns True if the atom will be restrained by the given restraint set."""
if
rset
==
"non_hydrogen"
:
if
rset
==
"non_hydrogen"
:
return
atom
.
element
.
name
!=
"hydrogen"
return
atom
.
element
.
name
!=
"hydrogen"
elif
rset
==
"c_alpha"
:
elif
rset
==
"c_alpha"
:
return
atom
.
name
==
"CA"
return
atom
.
name
==
"CA"
def
_add_restraints
(
def
_add_restraints
(
...
@@ -51,24 +51,29 @@ def _add_restraints(
...
@@ -51,24 +51,29 @@ def _add_restraints(
reference_pdb
:
openmm_app
.
PDBFile
,
reference_pdb
:
openmm_app
.
PDBFile
,
stiffness
:
unit
.
Unit
,
stiffness
:
unit
.
Unit
,
rset
:
str
,
rset
:
str
,
exclude_residues
:
Sequence
[
int
]):
exclude_residues
:
Sequence
[
int
],
"""Adds a harmonic potential that restrains the system to a structure."""
):
assert
rset
in
[
"non_hydrogen"
,
"c_alpha"
]
"""Adds a harmonic potential that restrains the system to a structure."""
assert
rset
in
[
"non_hydrogen"
,
"c_alpha"
]
force
=
openmm
.
CustomExternalForce
(
"0.5 * k * ((x-x0)^2 + (y-y0)^2 + (z-z0)^2)"
)
force
=
openmm
.
CustomExternalForce
(
force
.
addGlobalParameter
(
"k"
,
stiffness
)
"0.5 * k * ((x-x0)^2 + (y-y0)^2 + (z-z0)^2)"
for
p
in
[
"x0"
,
"y0"
,
"z0"
]:
)
force
.
addPerParticleParameter
(
p
)
force
.
addGlobalParameter
(
"k"
,
stiffness
)
for
p
in
[
"x0"
,
"y0"
,
"z0"
]:
for
i
,
atom
in
enumerate
(
reference_pdb
.
topology
.
atoms
()):
force
.
addPerParticleParameter
(
p
)
if
atom
.
residue
.
index
in
exclude_residues
:
continue
for
i
,
atom
in
enumerate
(
reference_pdb
.
topology
.
atoms
()):
if
will_restrain
(
atom
,
rset
):
if
atom
.
residue
.
index
in
exclude_residues
:
force
.
addParticle
(
i
,
reference_pdb
.
positions
[
i
])
continue
logging
.
info
(
"Restraining %d / %d particles."
,
if
will_restrain
(
atom
,
rset
):
force
.
getNumParticles
(),
system
.
getNumParticles
())
force
.
addParticle
(
i
,
reference_pdb
.
positions
[
i
])
system
.
addForce
(
force
)
logging
.
info
(
"Restraining %d / %d particles."
,
force
.
getNumParticles
(),
system
.
getNumParticles
(),
)
system
.
addForce
(
force
)
def
_openmm_minimize
(
def
_openmm_minimize
(
...
@@ -77,291 +82,324 @@ def _openmm_minimize(
...
@@ -77,291 +82,324 @@ def _openmm_minimize(
tolerance
:
unit
.
Unit
,
tolerance
:
unit
.
Unit
,
stiffness
:
unit
.
Unit
,
stiffness
:
unit
.
Unit
,
restraint_set
:
str
,
restraint_set
:
str
,
exclude_residues
:
Sequence
[
int
]
):
exclude_residues
:
Sequence
[
int
]
,
"""Minimize energy via openmm."""
):
"""Minimize energy via openmm."""
pdb_file
=
io
.
StringIO
(
pdb_str
)
pdb
=
openmm_app
.
PDBFile
(
pdb_
file
)
pdb
_file
=
io
.
StringIO
(
pdb_
str
)
pdb
=
openmm_app
.
PDBFile
(
pdb_file
)
force_field
=
openmm_app
.
ForceField
(
"amber99sb.xml"
)
constraints
=
openmm_app
.
HBonds
force_field
=
openmm_app
.
ForceField
(
"amber99sb.xml"
)
system
=
force_field
.
createSystem
(
constraints
=
openmm_app
.
HBonds
pdb
.
topology
,
constraints
=
constraints
)
system
=
force_field
.
createSystem
(
pdb
.
topology
,
constraints
=
constraints
)
if
stiffness
>
0
*
ENERGY
/
(
LENGTH
**
2
):
if
stiffness
>
0
*
ENERGY
/
(
LENGTH
**
2
):
_add_restraints
(
system
,
pdb
,
stiffness
,
restraint_set
,
exclude_residues
)
_add_restraints
(
system
,
pdb
,
stiffness
,
restraint_set
,
exclude_residues
)
integrator
=
openmm
.
LangevinIntegrator
(
0
,
0.01
,
0.0
)
integrator
=
openmm
.
LangevinIntegrator
(
0
,
0.01
,
0.0
)
platform
=
openmm
.
Platform
.
getPlatformByName
(
"CPU"
)
platform
=
openmm
.
Platform
.
getPlatformByName
(
"CPU"
)
simulation
=
openmm_app
.
Simulation
(
simulation
=
openmm_app
.
Simulation
(
pdb
.
topology
,
system
,
integrator
,
platform
)
pdb
.
topology
,
system
,
integrator
,
platform
simulation
.
context
.
setPositions
(
pdb
.
positions
)
)
simulation
.
context
.
setPositions
(
pdb
.
positions
)
ret
=
{}
state
=
simulation
.
context
.
getState
(
getEnergy
=
True
,
getPositions
=
True
)
ret
=
{}
ret
[
"einit"
]
=
state
.
getPotentialEnergy
().
value_in_unit
(
ENERGY
)
state
=
simulation
.
context
.
getState
(
getEnergy
=
True
,
getPositions
=
True
)
ret
[
"
pos
init"
]
=
state
.
getPo
sitions
(
asNumpy
=
True
).
value_in_unit
(
L
EN
GTH
)
ret
[
"
e
init"
]
=
state
.
getPo
tentialEnergy
(
).
value_in_unit
(
EN
ERGY
)
simulation
.
minimizeEnergy
(
maxIterations
=
max_iterations
,
ret
[
"posinit"
]
=
state
.
getPositions
(
asNumpy
=
True
).
value_in_unit
(
LENGTH
)
tolerance
=
tolerance
)
simulation
.
minimizeEnergy
(
maxIterations
=
max_iterations
,
tolerance
=
tolerance
)
state
=
simulation
.
context
.
getState
(
getEnergy
=
True
,
getPositions
=
True
)
state
=
simulation
.
context
.
getState
(
getEnergy
=
True
,
getPositions
=
True
)
ret
[
"efinal"
]
=
state
.
getPotentialEnergy
().
value_in_unit
(
ENERGY
)
ret
[
"efinal"
]
=
state
.
getPotentialEnergy
().
value_in_unit
(
ENERGY
)
ret
[
"pos"
]
=
state
.
getPositions
(
asNumpy
=
True
).
value_in_unit
(
LENGTH
)
ret
[
"pos"
]
=
state
.
getPositions
(
asNumpy
=
True
).
value_in_unit
(
LENGTH
)
ret
[
"min_pdb"
]
=
_get_pdb_string
(
simulation
.
topology
,
state
.
getPositions
())
ret
[
"min_pdb"
]
=
_get_pdb_string
(
simulation
.
topology
,
state
.
getPositions
())
return
ret
return
ret
def
_get_pdb_string
(
topology
:
openmm_app
.
Topology
,
positions
:
unit
.
Quantity
):
def
_get_pdb_string
(
topology
:
openmm_app
.
Topology
,
positions
:
unit
.
Quantity
):
"""Returns a pdb string provided OpenMM topology and positions."""
"""Returns a pdb string provided OpenMM topology and positions."""
with
io
.
StringIO
()
as
f
:
with
io
.
StringIO
()
as
f
:
openmm_app
.
PDBFile
.
writeFile
(
topology
,
positions
,
f
)
openmm_app
.
PDBFile
.
writeFile
(
topology
,
positions
,
f
)
return
f
.
getvalue
()
return
f
.
getvalue
()
def
_check_cleaned_atoms
(
pdb_cleaned_string
:
str
,
pdb_ref_string
:
str
):
def
_check_cleaned_atoms
(
pdb_cleaned_string
:
str
,
pdb_ref_string
:
str
):
"""Checks that no atom positions have been altered by cleaning."""
"""Checks that no atom positions have been altered by cleaning."""
cleaned
=
openmm_app
.
PDBFile
(
io
.
StringIO
(
pdb_cleaned_string
))
cleaned
=
openmm_app
.
PDBFile
(
io
.
StringIO
(
pdb_cleaned_string
))
reference
=
openmm_app
.
PDBFile
(
io
.
StringIO
(
pdb_ref_string
))
reference
=
openmm_app
.
PDBFile
(
io
.
StringIO
(
pdb_ref_string
))
cl_xyz
=
np
.
array
(
cleaned
.
getPositions
().
value_in_unit
(
LENGTH
))
cl_xyz
=
np
.
array
(
cleaned
.
getPositions
().
value_in_unit
(
LENGTH
))
ref_xyz
=
np
.
array
(
reference
.
getPositions
().
value_in_unit
(
LENGTH
))
ref_xyz
=
np
.
array
(
reference
.
getPositions
().
value_in_unit
(
LENGTH
))
for
ref_res
,
cl_res
in
zip
(
reference
.
topology
.
residues
(),
for
ref_res
,
cl_res
in
zip
(
cleaned
.
topology
.
residues
()):
reference
.
topology
.
residues
(),
cleaned
.
topology
.
residues
()
assert
ref_res
.
name
==
cl_res
.
name
):
for
rat
in
ref_res
.
atoms
():
assert
ref_res
.
name
==
cl_res
.
name
for
cat
in
cl_res
.
atoms
():
for
rat
in
ref_res
.
atoms
():
if
cat
.
name
==
rat
.
name
:
for
cat
in
cl_res
.
atoms
():
if
not
np
.
array_equal
(
cl_xyz
[
cat
.
index
],
ref_xyz
[
rat
.
index
]):
if
cat
.
name
==
rat
.
name
:
raise
ValueError
(
f
"Coordinates of cleaned atom
{
cat
}
do not match "
if
not
np
.
array_equal
(
f
"coordinates of reference atom
{
rat
}
."
)
cl_xyz
[
cat
.
index
],
ref_xyz
[
rat
.
index
]
):
raise
ValueError
(
f
"Coordinates of cleaned atom
{
cat
}
do not match "
f
"coordinates of reference atom
{
rat
}
."
)
def
_check_residues_are_well_defined
(
prot
:
protein
.
Protein
):
def
_check_residues_are_well_defined
(
prot
:
protein
.
Protein
):
"""Checks that all residues contain non-empty atom sets."""
"""Checks that all residues contain non-empty atom sets."""
if
(
prot
.
atom_mask
.
sum
(
axis
=-
1
)
==
0
).
any
():
if
(
prot
.
atom_mask
.
sum
(
axis
=-
1
)
==
0
).
any
():
raise
ValueError
(
"Amber minimization can only be performed on proteins with"
raise
ValueError
(
" well-defined residues. This protein contains at least"
"Amber minimization can only be performed on proteins with"
" one residue with no atoms."
)
" well-defined residues. This protein contains at least"
" one residue with no atoms."
)
def
_check_atom_mask_is_ideal
(
prot
):
def
_check_atom_mask_is_ideal
(
prot
):
"""Sanity-check the atom mask is ideal, up to a possible OXT."""
"""Sanity-check the atom mask is ideal, up to a possible OXT."""
atom_mask
=
prot
.
atom_mask
atom_mask
=
prot
.
atom_mask
ideal_atom_mask
=
protein
.
ideal_atom_mask
(
prot
)
ideal_atom_mask
=
protein
.
ideal_atom_mask
(
prot
)
utils
.
assert_equal_nonterminal_atom_types
(
atom_mask
,
ideal_atom_mask
)
utils
.
assert_equal_nonterminal_atom_types
(
atom_mask
,
ideal_atom_mask
)
def
clean_protein
(
def
clean_protein
(
prot
:
protein
.
Protein
,
checks
:
bool
=
True
):
prot
:
protein
.
Protein
,
"""Adds missing atoms to Protein instance.
checks
:
bool
=
True
):
"""Adds missing atoms to Protein instance.
Args:
prot: A `protein.Protein` instance.
Args:
checks: A `bool` specifying whether to add additional checks to the cleaning
prot: A `protein.Protein` instance.
process.
checks: A `bool` specifying whether to add additional checks to the cleaning
process.
Returns:
pdb_string: A string of the cleaned protein.
Returns:
"""
pdb_string: A string of the cleaned protein.
_check_atom_mask_is_ideal
(
prot
)
"""
_check_atom_mask_is_ideal
(
prot
)
# Clean pdb.
prot_pdb_string
=
protein
.
to_pdb
(
prot
)
# Clean pdb.
pdb_file
=
io
.
StringIO
(
prot_pdb_string
)
prot_pdb_string
=
protein
.
to_pdb
(
prot
)
alterations_info
=
{}
pdb_file
=
io
.
StringIO
(
prot_pdb_string
)
fixed_pdb
=
cleanup
.
fix_pdb
(
pdb_file
,
alterations_info
)
alterations_info
=
{}
fixed_pdb_file
=
io
.
StringIO
(
fixed_pdb
)
fixed_pdb
=
cleanup
.
fix_pdb
(
pdb_file
,
alterations_info
)
pdb_structure
=
PdbStructure
(
fixed_pdb_file
)
fixed_pdb_file
=
io
.
StringIO
(
fixed_pdb
)
cleanup
.
clean_structure
(
pdb_structure
,
alterations_info
)
pdb_structure
=
PdbStructure
(
fixed_pdb_file
)
cleanup
.
clean_structure
(
pdb_structure
,
alterations_info
)
logging
.
info
(
"alterations info: %s"
,
alterations_info
)
logging
.
info
(
"alterations info: %s"
,
alterations_info
)
# Write pdb file of cleaned structure.
as_file
=
openmm_app
.
PDBFile
(
pdb_structure
)
# Write pdb file of cleaned structure.
pdb_string
=
_get_pdb_string
(
as_file
.
getTopology
(),
as_file
.
getPositions
())
as_file
=
openmm_app
.
PDBFile
(
pdb_structure
)
if
checks
:
pdb_string
=
_get_pdb_string
(
as_file
.
getTopology
(),
as_file
.
getPositions
())
_check_cleaned_atoms
(
pdb_string
,
prot_pdb_string
)
if
checks
:
return
pdb_string
_check_cleaned_atoms
(
pdb_string
,
prot_pdb_string
)
return
pdb_string
def
make_atom14_positions
(
prot
):
def
make_atom14_positions
(
prot
):
"""Constructs denser atom positions (14 dimensions instead of 37)."""
"""Constructs denser atom positions (14 dimensions instead of 37)."""
restype_atom14_to_atom37
=
[]
# mapping (restype, atom14) --> atom37
restype_atom14_to_atom37
=
[]
# mapping (restype, atom14) --> atom37
restype_atom37_to_atom14
=
[]
# mapping (restype, atom37) --> atom14
restype_atom37_to_atom14
=
[]
# mapping (restype, atom37) --> atom14
restype_atom14_mask
=
[]
restype_atom14_mask
=
[]
for
rt
in
residue_constants
.
restypes
:
for
rt
in
residue_constants
.
restypes
:
atom_names
=
residue_constants
.
restype_name_to_atom14_names
[
atom_names
=
residue_constants
.
restype_name_to_atom14_names
[
residue_constants
.
restype_1to3
[
rt
]]
residue_constants
.
restype_1to3
[
rt
]
]
restype_atom14_to_atom37
.
append
([
(
residue_constants
.
atom_order
[
name
]
if
name
else
0
)
restype_atom14_to_atom37
.
append
(
for
name
in
atom_names
[
])
(
residue_constants
.
atom_order
[
name
]
if
name
else
0
)
for
name
in
atom_names
atom_name_to_idx14
=
{
name
:
i
for
i
,
name
in
enumerate
(
atom_names
)}
]
restype_atom37_to_atom14
.
append
([
)
(
atom_name_to_idx14
[
name
]
if
name
in
atom_name_to_idx14
else
0
)
for
name
in
residue_constants
.
atom_types
atom_name_to_idx14
=
{
name
:
i
for
i
,
name
in
enumerate
(
atom_names
)}
])
restype_atom37_to_atom14
.
append
(
[
restype_atom14_mask
.
append
([(
1.
if
name
else
0.
)
for
name
in
atom_names
])
(
atom_name_to_idx14
[
name
]
if
name
in
atom_name_to_idx14
else
0
)
for
name
in
residue_constants
.
atom_types
# Add dummy mapping for restype 'UNK'.
]
restype_atom14_to_atom37
.
append
([
0
]
*
14
)
)
restype_atom37_to_atom14
.
append
([
0
]
*
37
)
restype_atom14_mask
.
append
([
0.
]
*
14
)
restype_atom14_mask
.
append
(
[(
1.0
if
name
else
0.0
)
for
name
in
atom_names
]
restype_atom14_to_atom37
=
np
.
array
(
restype_atom14_to_atom37
,
dtype
=
np
.
int32
)
)
restype_atom37_to_atom14
=
np
.
array
(
restype_atom37_to_atom14
,
dtype
=
np
.
int32
)
restype_atom14_mask
=
np
.
array
(
restype_atom14_mask
,
dtype
=
np
.
float32
)
# Add dummy mapping for restype 'UNK'.
restype_atom14_to_atom37
.
append
([
0
]
*
14
)
# Create the mapping for (residx, atom14) --> atom37, i.e. an array
restype_atom37_to_atom14
.
append
([
0
]
*
37
)
# with shape (num_res, 14) containing the atom37 indices for this protein.
restype_atom14_mask
.
append
([
0.0
]
*
14
)
residx_atom14_to_atom37
=
restype_atom14_to_atom37
[
prot
[
"aatype"
]]
residx_atom14_mask
=
restype_atom14_mask
[
prot
[
"aatype"
]]
restype_atom14_to_atom37
=
np
.
array
(
restype_atom14_to_atom37
,
dtype
=
np
.
int32
# Create a mask for known ground truth positions.
)
residx_atom14_gt_mask
=
residx_atom14_mask
*
np
.
take_along_axis
(
restype_atom37_to_atom14
=
np
.
array
(
prot
[
"all_atom_mask"
],
residx_atom14_to_atom37
,
axis
=
1
).
astype
(
np
.
float32
)
restype_atom37_to_atom14
,
dtype
=
np
.
int32
)
# Gather the ground truth positions.
restype_atom14_mask
=
np
.
array
(
restype_atom14_mask
,
dtype
=
np
.
float32
)
residx_atom14_gt_positions
=
residx_atom14_gt_mask
[:,
:,
None
]
*
(
np
.
take_along_axis
(
prot
[
"all_atom_positions"
],
# Create the mapping for (residx, atom14) --> atom37, i.e. an array
residx_atom14_to_atom37
[...,
None
],
# with shape (num_res, 14) containing the atom37 indices for this protein.
axis
=
1
))
residx_atom14_to_atom37
=
restype_atom14_to_atom37
[
prot
[
"aatype"
]]
residx_atom14_mask
=
restype_atom14_mask
[
prot
[
"aatype"
]]
prot
[
"atom14_atom_exists"
]
=
residx_atom14_mask
prot
[
"atom14_gt_exists"
]
=
residx_atom14_gt_mask
# Create a mask for known ground truth positions.
prot
[
"atom14_gt_positions"
]
=
residx_atom14_gt_positions
residx_atom14_gt_mask
=
residx_atom14_mask
*
np
.
take_along_axis
(
prot
[
"all_atom_mask"
],
residx_atom14_to_atom37
,
axis
=
1
prot
[
"residx_atom14_to_atom37"
]
=
residx_atom14_to_atom37
.
astype
(
np
.
int64
)
).
astype
(
np
.
float32
)
# Create the gather indices for mapping back.
# Gather the ground truth positions.
residx_atom37_to_atom14
=
restype_atom37_to_atom14
[
prot
[
"aatype"
]]
residx_atom14_gt_positions
=
residx_atom14_gt_mask
[:,
:,
None
]
*
(
prot
[
"residx_atom37_to_atom14"
]
=
residx_atom37_to_atom14
.
astype
(
np
.
int64
)
np
.
take_along_axis
(
prot
[
"all_atom_positions"
],
# Create the corresponding mask.
residx_atom14_to_atom37
[...,
None
],
restype_atom37_mask
=
np
.
zeros
([
21
,
37
],
dtype
=
np
.
float32
)
axis
=
1
,
for
restype
,
restype_letter
in
enumerate
(
residue_constants
.
restypes
):
)
restype_name
=
residue_constants
.
restype_1to3
[
restype_letter
]
)
atom_names
=
residue_constants
.
residue_atoms
[
restype_name
]
for
atom_name
in
atom_names
:
prot
[
"atom14_atom_exists"
]
=
residx_atom14_mask
atom_type
=
residue_constants
.
atom_order
[
atom_name
]
prot
[
"atom14_gt_exists"
]
=
residx_atom14_gt_mask
restype_atom37_mask
[
restype
,
atom_type
]
=
1
prot
[
"atom14_gt_positions"
]
=
residx_atom14_gt_positions
residx_atom37_mask
=
restype_atom37_mask
[
prot
[
"aatype"
]]
prot
[
"residx_atom14_to_atom37"
]
=
residx_atom14_to_atom37
.
astype
(
np
.
int64
)
prot
[
"atom37_atom_exists"
]
=
residx_atom37_mask
# Create the gather indices for mapping back.
# As the atom naming is ambiguous for 7 of the 20 amino acids, provide
residx_atom37_to_atom14
=
restype_atom37_to_atom14
[
prot
[
"aatype"
]]
# alternative ground truth coordinates where the naming is swapped
prot
[
"residx_atom37_to_atom14"
]
=
residx_atom37_to_atom14
.
astype
(
np
.
int64
)
restype_3
=
[
residue_constants
.
restype_1to3
[
res
]
for
res
in
residue_constants
.
restypes
# Create the corresponding mask.
]
restype_atom37_mask
=
np
.
zeros
([
21
,
37
],
dtype
=
np
.
float32
)
restype_3
+=
[
"UNK"
]
for
restype
,
restype_letter
in
enumerate
(
residue_constants
.
restypes
):
restype_name
=
residue_constants
.
restype_1to3
[
restype_letter
]
# Matrices for renaming ambiguous atoms.
atom_names
=
residue_constants
.
residue_atoms
[
restype_name
]
all_matrices
=
{
res
:
np
.
eye
(
14
,
dtype
=
np
.
float32
)
for
res
in
restype_3
}
for
atom_name
in
atom_names
:
for
resname
,
swap
in
residue_constants
.
residue_atom_renaming_swaps
.
items
():
atom_type
=
residue_constants
.
atom_order
[
atom_name
]
correspondences
=
np
.
arange
(
14
)
restype_atom37_mask
[
restype
,
atom_type
]
=
1
for
source_atom_swap
,
target_atom_swap
in
swap
.
items
():
source_index
=
residue_constants
.
restype_name_to_atom14_names
[
residx_atom37_mask
=
restype_atom37_mask
[
prot
[
"aatype"
]]
resname
].
index
(
source_atom_swap
)
prot
[
"atom37_atom_exists"
]
=
residx_atom37_mask
target_index
=
residue_constants
.
restype_name_to_atom14_names
[
resname
].
index
(
target_atom_swap
)
# As the atom naming is ambiguous for 7 of the 20 amino acids, provide
correspondences
[
source_index
]
=
target_index
# alternative ground truth coordinates where the naming is swapped
correspondences
[
target_index
]
=
source_index
restype_3
=
[
renaming_matrix
=
np
.
zeros
((
14
,
14
),
dtype
=
np
.
float32
)
residue_constants
.
restype_1to3
[
res
]
for
index
,
correspondence
in
enumerate
(
correspondences
):
for
res
in
residue_constants
.
restypes
renaming_matrix
[
index
,
correspondence
]
=
1.
]
all_matrices
[
resname
]
=
renaming_matrix
.
astype
(
np
.
float32
)
restype_3
+=
[
"UNK"
]
renaming_matrices
=
np
.
stack
([
all_matrices
[
restype
]
for
restype
in
restype_3
])
# Matrices for renaming ambiguous atoms.
# Pick the transformation matrices for the given residue sequence
all_matrices
=
{
res
:
np
.
eye
(
14
,
dtype
=
np
.
float32
)
for
res
in
restype_3
}
# shape (num_res, 14, 14).
for
resname
,
swap
in
residue_constants
.
residue_atom_renaming_swaps
.
items
():
renaming_transform
=
renaming_matrices
[
prot
[
"aatype"
]]
correspondences
=
np
.
arange
(
14
)
for
source_atom_swap
,
target_atom_swap
in
swap
.
items
():
# Apply it to the ground truth positions. shape (num_res, 14, 3).
source_index
=
residue_constants
.
restype_name_to_atom14_names
[
alternative_gt_positions
=
np
.
einsum
(
"rac,rab->rbc"
,
resname
residx_atom14_gt_positions
,
].
index
(
source_atom_swap
)
renaming_transform
)
target_index
=
residue_constants
.
restype_name_to_atom14_names
[
prot
[
"atom14_alt_gt_positions"
]
=
alternative_gt_positions
resname
].
index
(
target_atom_swap
)
# Create the mask for the alternative ground truth (differs from the
correspondences
[
source_index
]
=
target_index
# ground truth mask, if only one of the atoms in an ambiguous pair has a
correspondences
[
target_index
]
=
source_index
# ground truth position).
renaming_matrix
=
np
.
zeros
((
14
,
14
),
dtype
=
np
.
float32
)
alternative_gt_mask
=
np
.
einsum
(
"ra,rab->rb"
,
for
index
,
correspondence
in
enumerate
(
correspondences
):
residx_atom14_gt_mask
,
renaming_matrix
[
index
,
correspondence
]
=
1.0
renaming_transform
)
all_matrices
[
resname
]
=
renaming_matrix
.
astype
(
np
.
float32
)
renaming_matrices
=
np
.
stack
(
prot
[
"atom14_alt_gt_exists"
]
=
alternative_gt_mask
[
all_matrices
[
restype
]
for
restype
in
restype_3
]
)
# Create an ambiguous atoms mask. shape: (21, 14).
restype_atom14_is_ambiguous
=
np
.
zeros
((
21
,
14
),
dtype
=
np
.
float32
)
# Pick the transformation matrices for the given residue sequence
for
resname
,
swap
in
residue_constants
.
residue_atom_renaming_swaps
.
items
():
# shape (num_res, 14, 14).
for
atom_name1
,
atom_name2
in
swap
.
items
():
renaming_transform
=
renaming_matrices
[
prot
[
"aatype"
]]
restype
=
residue_constants
.
restype_order
[
residue_constants
.
restype_3to1
[
resname
]]
# Apply it to the ground truth positions. shape (num_res, 14, 3).
atom_idx1
=
residue_constants
.
restype_name_to_atom14_names
[
resname
].
index
(
alternative_gt_positions
=
np
.
einsum
(
atom_name1
)
"rac,rab->rbc"
,
residx_atom14_gt_positions
,
renaming_transform
atom_idx2
=
residue_constants
.
restype_name_to_atom14_names
[
resname
].
index
(
)
atom_name2
)
prot
[
"atom14_alt_gt_positions"
]
=
alternative_gt_positions
restype_atom14_is_ambiguous
[
restype
,
atom_idx1
]
=
1
restype_atom14_is_ambiguous
[
restype
,
atom_idx2
]
=
1
# Create the mask for the alternative ground truth (differs from the
# ground truth mask, if only one of the atoms in an ambiguous pair has a
# From this create an ambiguous_mask for the given sequence.
# ground truth position).
prot
[
"atom14_atom_is_ambiguous"
]
=
(
alternative_gt_mask
=
np
.
einsum
(
restype_atom14_is_ambiguous
[
prot
[
"aatype"
]])
"ra,rab->rb"
,
residx_atom14_gt_mask
,
renaming_transform
)
return
prot
prot
[
"atom14_alt_gt_exists"
]
=
alternative_gt_mask
# Create an ambiguous atoms mask. shape: (21, 14).
restype_atom14_is_ambiguous
=
np
.
zeros
((
21
,
14
),
dtype
=
np
.
float32
)
for
resname
,
swap
in
residue_constants
.
residue_atom_renaming_swaps
.
items
():
for
atom_name1
,
atom_name2
in
swap
.
items
():
restype
=
residue_constants
.
restype_order
[
residue_constants
.
restype_3to1
[
resname
]
]
atom_idx1
=
residue_constants
.
restype_name_to_atom14_names
[
resname
].
index
(
atom_name1
)
atom_idx2
=
residue_constants
.
restype_name_to_atom14_names
[
resname
].
index
(
atom_name2
)
restype_atom14_is_ambiguous
[
restype
,
atom_idx1
]
=
1
restype_atom14_is_ambiguous
[
restype
,
atom_idx2
]
=
1
# From this create an ambiguous_mask for the given sequence.
prot
[
"atom14_atom_is_ambiguous"
]
=
restype_atom14_is_ambiguous
[
prot
[
"aatype"
]
]
return
prot
def find_violations(prot_np: protein.Protein):
    """Analyzes a protein and returns structural violation information.

    Args:
        prot_np: A protein.

    Returns:
        violations: A `dict` of structure components with structural violations.
        violation_metrics: A `dict` of violation metrics.
    """
    # Assemble the minimal feature dict expected by the numpy loss helpers.
    batch = {
        "aatype": prot_np.aatype,
        "all_atom_positions": prot_np.atom_positions.astype(np.float32),
        "all_atom_mask": prot_np.atom_mask.astype(np.float32),
        "residue_index": prot_np.residue_index,
    }
    batch["seq_mask"] = np.ones_like(batch["aatype"], np.float32)

    # Derive the atom14 representation (positions, masks, renaming info).
    batch = make_atom14_positions(batch)

    # Evaluate the ground-truth coordinates against the violation criteria.
    violations = loss.find_structural_violations_np(
        batch=batch,
        atom14_pred_positions=batch["atom14_gt_positions"],
        config=ml_collections.ConfigDict(
            {
                "violation_tolerance_factor": 12,  # Taken from model config.
                "clash_overlap_tolerance": 1.5,  # Taken from model config.
            }
        ),
    )
    violation_metrics = loss.compute_violation_metrics_np(
        batch=batch,
        atom14_pred_positions=batch["atom14_gt_positions"],
        violations=violations,
    )

    return violations, violation_metrics
def get_violation_metrics(prot: protein.Protein):
    """Computes violation and alignment metrics."""
    structural_violations, struct_metrics = find_violations(prot)

    # Indices of residues that participate in any structural violation.
    violation_idx = np.flatnonzero(
        structural_violations["total_per_residue_violations_mask"]
    )

    struct_metrics["residue_violations"] = violation_idx
    struct_metrics["num_residue_violations"] = len(violation_idx)
    struct_metrics["structural_violations"] = structural_violations
    return struct_metrics
# NOTE(review): the diff view elides the leading parameters of this signature;
# `pdb_string`, `max_iterations` and `tolerance` are reconstructed from the
# docstring below — confirm names/ordering (and any keyword-only marker)
# against the upstream source.
def _run_one_iteration(
    pdb_string: str,
    max_iterations: int,
    tolerance: float,
    stiffness: float,
    restraint_set: str,
    max_attempts: int,
    exclude_residues: Optional[Collection[int]] = None,
):
    """Runs the minimization pipeline.

    Args:
        pdb_string: A pdb string.
        max_iterations: An `int` specifying the maximum number of L-BFGS
            iterations. A value of 0 specifies no limit.
        tolerance: kcal/mol, the energy tolerance of L-BFGS.
        stiffness: kcal/mol A**2, spring constant of heavy atom restraining
            potential.
        restraint_set: The set of atoms to restrain.
        max_attempts: The maximum number of minimization attempts.
        exclude_residues: An optional list of zero-indexed residues to exclude
            from restraints.

    Returns:
        A `dict` of minimization info.
    """
    exclude_residues = exclude_residues or []

    # Assign physical dimensions.
    tolerance = tolerance * ENERGY
    stiffness = stiffness * ENERGY / (LENGTH ** 2)

    start = time.time()
    minimized = False
    attempts = 0
    # Retry the minimization: OpenMM occasionally fails transiently, so we
    # loop until success or the attempt budget is spent.
    while not minimized and attempts < max_attempts:
        attempts += 1
        try:
            logging.info(
                "Minimizing protein, attempt %d of %d.", attempts, max_attempts
            )
            ret = _openmm_minimize(
                pdb_string,
                max_iterations=max_iterations,
                tolerance=tolerance,
                stiffness=stiffness,
                restraint_set=restraint_set,
                exclude_residues=exclude_residues,
            )
            minimized = True
        except Exception as e:  # pylint: disable=broad-except
            logging.info(e)

    if not minimized:
        raise ValueError(f"Minimization failed after {max_attempts} attempts.")

    ret["opt_time"] = time.time() - start
    ret["min_attempts"] = attempts
    return ret
# NOTE(review): the diff view elides the leading parameters of this signature;
# `prot`, `stiffness`, `max_outer_iterations`, `place_hydrogens_every_iteration`,
# `max_iterations` and `tolerance` (and their defaults) are reconstructed from
# the docstring below — confirm against the upstream source.
def run_pipeline(
    prot: protein.Protein,
    stiffness: float,
    max_outer_iterations: int = 1,
    place_hydrogens_every_iteration: bool = True,
    max_iterations: int = 0,
    tolerance: float = 2.39,
    restraint_set: str = "non_hydrogen",
    max_attempts: int = 100,
    checks: bool = True,
    exclude_residues: Optional[Sequence[int]] = None,
):
    """Run iterative amber relax.

    Successive relax iterations are performed until all violations have been
    resolved. Each iteration involves a restrained Amber minimization, with
    restraint exclusions determined by violation-participating residues.

    Args:
        prot: A protein to be relaxed.
        stiffness: kcal/mol A**2, the restraint stiffness.
        max_outer_iterations: The maximum number of iterative minimization.
        place_hydrogens_every_iteration: Whether hydrogens are re-initialized
            prior to every minimization.
        max_iterations: An `int` specifying the maximum number of L-BFGS steps
            per relax iteration. A value of 0 specifies no limit.
        tolerance: kcal/mol, the energy tolerance of L-BFGS.
            The default value is the OpenMM default.
        restraint_set: The set of atoms to restrain.
        max_attempts: The maximum number of minimization attempts per iteration.
        checks: Whether to perform cleaning checks.
        exclude_residues: An optional list of zero-indexed residues to exclude
            from restraints.

    Returns:
        out: A dictionary of output values.
    """
    # `protein.to_pdb` will strip any poorly-defined residues so we need to
    # perform this check before `clean_protein`.
    _check_residues_are_well_defined(prot)
    pdb_string = clean_protein(prot, checks=checks)

    exclude_residues = exclude_residues or []
    exclude_residues = set(exclude_residues)
    violations = np.inf
    iteration = 0

    while violations > 0 and iteration < max_outer_iterations:
        ret = _run_one_iteration(
            pdb_string=pdb_string,
            exclude_residues=exclude_residues,
            max_iterations=max_iterations,
            tolerance=tolerance,
            stiffness=stiffness,
            restraint_set=restraint_set,
            max_attempts=max_attempts,
        )
        prot = protein.from_pdb_string(ret["min_pdb"])
        # Optionally re-place hydrogens before the next minimization pass.
        if place_hydrogens_every_iteration:
            pdb_string = clean_protein(prot, checks=True)
        else:
            pdb_string = ret["min_pdb"]
        ret.update(get_violation_metrics(prot))
        ret.update(
            {
                "num_exclusions": len(exclude_residues),
                "iteration": iteration,
            }
        )
        violations = ret["violations_per_residue"]
        # Residues that violated this round are excluded from restraints next
        # round, so the minimizer is free to move them.
        exclude_residues = exclude_residues.union(ret["residue_violations"])

        logging.info(
            "Iteration completed: Einit %.2f Efinal %.2f Time %.2f s "
            "num residue violations %d num residue exclusions %d ",
            ret["einit"],
            ret["efinal"],
            ret["opt_time"],
            ret["num_residue_violations"],
            ret["num_exclusions"],
        )
        iteration += 1
    return ret
def get_initial_energies(
    pdb_strs: Sequence[str],
    stiffness: float = 0.0,
    restraint_set: str = "non_hydrogen",
    exclude_residues: Optional[Sequence[int]] = None,
):
    """Returns initial potential energies for a sequence of PDBs.

    Assumes the input PDBs are ready for minimization, and all have the same
    topology.
    Allows time to be saved by not pdbfixing / rebuilding the system.

    Args:
        pdb_strs: List of PDB strings.
        stiffness: kcal/mol A**2, spring constant of heavy atom restraining
            potential.
        restraint_set: Which atom types to restrain.
        exclude_residues: An optional list of zero-indexed residues to exclude
            from restraints.

    Returns:
        A list of initial energies in the same order as pdb_strs.
    """
    exclude_residues = exclude_residues or []

    openmm_pdbs = [
        openmm_app.PDBFile(PdbStructure(io.StringIO(p))) for p in pdb_strs
    ]
    # Build the system once from the first PDB; all inputs share a topology.
    force_field = openmm_app.ForceField("amber99sb.xml")
    system = force_field.createSystem(
        openmm_pdbs[0].topology, constraints=openmm_app.HBonds
    )
    stiffness = stiffness * ENERGY / (LENGTH ** 2)
    if stiffness > 0 * ENERGY / (LENGTH ** 2):
        _add_restraints(
            system, openmm_pdbs[0], stiffness, restraint_set, exclude_residues
        )
    simulation = openmm_app.Simulation(
        openmm_pdbs[0].topology,
        system,
        openmm.LangevinIntegrator(0, 0.01, 0.0),
        openmm.Platform.getPlatformByName("CPU"),
    )

    energies = []
    for pdb in openmm_pdbs:
        try:
            simulation.context.setPositions(pdb.positions)
            state = simulation.context.getState(getEnergy=True)
            energies.append(state.getPotentialEnergy().value_in_unit(ENERGY))
        except Exception as e:  # pylint: disable=broad-except
            # Best-effort: report a sentinel energy rather than abort the batch.
            logging.error(
                "Error getting initial energy, returning large value %s", e
            )
            energies.append(unit.Quantity(1e20, ENERGY))
    return energies
openfold/np/relax/cleanup.py
View file @
07e64267
...
@@ -25,103 +25,107 @@ from simtk.openmm.app import element
...
@@ -25,103 +25,107 @@ from simtk.openmm.app import element
def fix_pdb(pdbfile, alterations_info):
    """Apply pdbfixer to the contents of a PDB file; return a PDB string result.

    1) Replaces nonstandard residues.
    2) Removes heterogens (non protein residues) including water.
    3) Adds missing residues and missing atoms within existing residues.
    4) Adds hydrogens assuming pH=7.0.
    5) KeepIds is currently true, so the fixer must keep the existing chain and
       residue identifiers. This will fail for some files in wider PDB that have
       invalid IDs.

    Args:
        pdbfile: Input PDB file handle.
        alterations_info: A dict that will store details of changes made.

    Returns:
        A PDB string representing the fixed structure.
    """
    fixer = pdbfixer.PDBFixer(pdbfile=pdbfile)

    # Record each class of alteration before applying it.
    fixer.findNonstandardResidues()
    alterations_info["nonstandard_residues"] = fixer.nonstandardResidues
    fixer.replaceNonstandardResidues()

    _remove_heterogens(fixer, alterations_info, keep_water=False)

    fixer.findMissingResidues()
    alterations_info["missing_residues"] = fixer.missingResidues
    fixer.findMissingAtoms()
    alterations_info["missing_heavy_atoms"] = fixer.missingAtoms
    alterations_info["missing_terminals"] = fixer.missingTerminals
    fixer.addMissingAtoms(seed=0)  # Fixed seed keeps rebuilt atoms deterministic.
    fixer.addMissingHydrogens()

    out_handle = io.StringIO()
    app.PDBFile.writeFile(
        fixer.topology, fixer.positions, out_handle, keepIds=True
    )
    return out_handle.getvalue()
def clean_structure(pdb_structure, alterations_info):
    """Applies additional fixes to an OpenMM structure, to handle edge cases.

    Args:
        pdb_structure: An OpenMM structure to modify and fix.
        alterations_info: A dict that will store details of changes made.
    """
    _replace_met_se(pdb_structure, alterations_info)
    _remove_chains_of_length_one(pdb_structure, alterations_info)
def
_remove_heterogens
(
fixer
,
alterations_info
,
keep_water
):
def
_remove_heterogens
(
fixer
,
alterations_info
,
keep_water
):
"""Removes the residues that Pdbfixer considers to be heterogens.
"""Removes the residues that Pdbfixer considers to be heterogens.
Args:
Args:
fixer: A Pdbfixer instance.
fixer: A Pdbfixer instance.
alterations_info: A dict that will store details of changes made.
alterations_info: A dict that will store details of changes made.
keep_water: If True, water (HOH) is not considered to be a heterogen.
keep_water: If True, water (HOH) is not considered to be a heterogen.
"""
"""
initial_resnames
=
set
()
initial_resnames
=
set
()
for
chain
in
fixer
.
topology
.
chains
():
for
chain
in
fixer
.
topology
.
chains
():
for
residue
in
chain
.
residues
():
for
residue
in
chain
.
residues
():
initial_resnames
.
add
(
residue
.
name
)
initial_resnames
.
add
(
residue
.
name
)
fixer
.
removeHeterogens
(
keepWater
=
keep_water
)
fixer
.
removeHeterogens
(
keepWater
=
keep_water
)
final_resnames
=
set
()
final_resnames
=
set
()
for
chain
in
fixer
.
topology
.
chains
():
for
chain
in
fixer
.
topology
.
chains
():
for
residue
in
chain
.
residues
():
for
residue
in
chain
.
residues
():
final_resnames
.
add
(
residue
.
name
)
final_resnames
.
add
(
residue
.
name
)
alterations_info
[
'removed_heterogens'
]
=
(
alterations_info
[
"removed_heterogens"
]
=
initial_resnames
.
difference
(
initial_resnames
.
difference
(
final_resnames
))
final_resnames
)
def
_replace_met_se
(
pdb_structure
,
alterations_info
):
def
_replace_met_se
(
pdb_structure
,
alterations_info
):
"""Replace the Se in any MET residues that were not marked as modified."""
"""Replace the Se in any MET residues that were not marked as modified."""
modified_met_residues
=
[]
modified_met_residues
=
[]
for
res
in
pdb_structure
.
iter_residues
():
for
res
in
pdb_structure
.
iter_residues
():
name
=
res
.
get_name_with_spaces
().
strip
()
name
=
res
.
get_name_with_spaces
().
strip
()
if
name
==
'
MET
'
:
if
name
==
"
MET
"
:
s_atom
=
res
.
get_atom
(
'
SD
'
)
s_atom
=
res
.
get_atom
(
"
SD
"
)
if
s_atom
.
element_symbol
==
'
Se
'
:
if
s_atom
.
element_symbol
==
"
Se
"
:
s_atom
.
element_symbol
=
'S'
s_atom
.
element_symbol
=
"S"
s_atom
.
element
=
element
.
get_by_symbol
(
'S'
)
s_atom
.
element
=
element
.
get_by_symbol
(
"S"
)
modified_met_residues
.
append
(
s_atom
.
residue_number
)
modified_met_residues
.
append
(
s_atom
.
residue_number
)
alterations_info
[
'
Se_in_MET
'
]
=
modified_met_residues
alterations_info
[
"
Se_in_MET
"
]
=
modified_met_residues
def
_remove_chains_of_length_one
(
pdb_structure
,
alterations_info
):
def
_remove_chains_of_length_one
(
pdb_structure
,
alterations_info
):
"""Removes chains that correspond to a single amino acid.
"""Removes chains that correspond to a single amino acid.
A single amino acid in a chain is both N and C terminus. There is no force
A single amino acid in a chain is both N and C terminus. There is no force
template for this case.
template for this case.
Args:
Args:
pdb_structure: An OpenMM pdb_structure to modify and fix.
pdb_structure: An OpenMM pdb_structure to modify and fix.
alterations_info: A dict that will store details of changes made.
alterations_info: A dict that will store details of changes made.
"""
"""
removed_chains
=
{}
removed_chains
=
{}
for
model
in
pdb_structure
.
iter_models
():
for
model
in
pdb_structure
.
iter_models
():
valid_chains
=
[
c
for
c
in
model
.
iter_chains
()
if
len
(
c
)
>
1
]
valid_chains
=
[
c
for
c
in
model
.
iter_chains
()
if
len
(
c
)
>
1
]
invalid_chain_ids
=
[
c
.
chain_id
for
c
in
model
.
iter_chains
()
if
len
(
c
)
<=
1
]
invalid_chain_ids
=
[
model
.
chains
=
valid_chains
c
.
chain_id
for
c
in
model
.
iter_chains
()
if
len
(
c
)
<=
1
for
chain_id
in
invalid_chain_ids
:
]
model
.
chains_by_id
.
pop
(
chain_id
)
model
.
chains
=
valid_chains
removed_chains
[
model
.
number
]
=
invalid_chain_ids
for
chain_id
in
invalid_chain_ids
:
alterations_info
[
'removed_chains'
]
=
removed_chains
model
.
chains_by_id
.
pop
(
chain_id
)
removed_chains
[
model
.
number
]
=
invalid_chain_ids
alterations_info
[
"removed_chains"
]
=
removed_chains
openfold/np/relax/relax.py
View file @
07e64267
...
@@ -21,60 +21,67 @@ import numpy as np
...
@@ -21,60 +21,67 @@ import numpy as np
class AmberRelaxation(object):
    """Amber relaxation."""

    def __init__(
        self,
        *,
        max_iterations: int,
        tolerance: float,
        stiffness: float,
        exclude_residues: Sequence[int],
        max_outer_iterations: int,
    ):
        """Initialize Amber Relaxer.

        Args:
            max_iterations: Maximum number of L-BFGS iterations. 0 means no max.
            tolerance: kcal/mol, the energy tolerance of L-BFGS.
            stiffness: kcal/mol A**2, spring constant of heavy atom restraining
                potential.
            exclude_residues: Residues to exclude from per-atom restraining.
                Zero-indexed.
            max_outer_iterations: Maximum number of violation-informed relax
                iterations. A value of 1 will run the non-iterative procedure
                used in CASP14. Use 20 so that >95% of the bad cases are
                relaxed. Relax finishes as soon as there are no violations,
                hence in most cases this causes no slowdown. In the worst case
                we do 20 outer iterations.
        """
        self._max_iterations = max_iterations
        self._tolerance = tolerance
        self._stiffness = stiffness
        self._exclude_residues = exclude_residues
        self._max_outer_iterations = max_outer_iterations

    def process(
        self, *, prot: protein.Protein
    ) -> Tuple[str, Dict[str, Any], np.ndarray]:
        """Runs Amber relax on a prediction, adds hydrogens, returns PDB string."""
        out = amber_minimize.run_pipeline(
            prot=prot,
            max_iterations=self._max_iterations,
            tolerance=self._tolerance,
            stiffness=self._stiffness,
            exclude_residues=self._exclude_residues,
            max_outer_iterations=self._max_outer_iterations,
        )

        # RMSD between the starting and minimized coordinates, for debugging.
        min_pos = out["pos"]
        start_pos = out["posinit"]
        rmsd = np.sqrt(
            np.sum((start_pos - min_pos) ** 2) / start_pos.shape[0]
        )
        debug_data = {
            "initial_energy": out["einit"],
            "final_energy": out["efinal"],
            "attempts": out["min_attempts"],
            "rmsd": rmsd,
        }

        # Rebuild a clean PDB carrying the minimized coordinates and the
        # original per-residue B-factors.
        pdb_str = amber_minimize.clean_protein(prot)
        min_pdb = utils.overwrite_pdb_coordinates(pdb_str, min_pos)
        min_pdb = utils.overwrite_b_factors(min_pdb, prot.b_factors)
        utils.assert_equal_nonterminal_atom_types(
            protein.from_pdb_string(min_pdb).atom_mask,
            prot.atom_mask,
        )
        violations = out["structural_violations"][
            "total_per_residue_violations_mask"
        ]
        return min_pdb, debug_data, violations
openfold/np/relax/utils.py
View file @
07e64267
...
@@ -23,59 +23,64 @@ from simtk.openmm.app.internal.pdbstructure import PdbStructure
...
@@ -23,59 +23,64 @@ from simtk.openmm.app.internal.pdbstructure import PdbStructure
def
overwrite_pdb_coordinates
(
pdb_str
:
str
,
pos
)
->
str
:
def
overwrite_pdb_coordinates
(
pdb_str
:
str
,
pos
)
->
str
:
pdb_file
=
io
.
StringIO
(
pdb_str
)
pdb_file
=
io
.
StringIO
(
pdb_str
)
structure
=
PdbStructure
(
pdb_file
)
structure
=
PdbStructure
(
pdb_file
)
topology
=
openmm_app
.
PDBFile
(
structure
).
getTopology
()
topology
=
openmm_app
.
PDBFile
(
structure
).
getTopology
()
with
io
.
StringIO
()
as
f
:
with
io
.
StringIO
()
as
f
:
openmm_app
.
PDBFile
.
writeFile
(
topology
,
pos
,
f
)
openmm_app
.
PDBFile
.
writeFile
(
topology
,
pos
,
f
)
return
f
.
getvalue
()
return
f
.
getvalue
()
def
overwrite_b_factors
(
pdb_str
:
str
,
bfactors
:
np
.
ndarray
)
->
str
:
def
overwrite_b_factors
(
pdb_str
:
str
,
bfactors
:
np
.
ndarray
)
->
str
:
"""Overwrites the B-factors in pdb_str with contents of bfactors array.
"""Overwrites the B-factors in pdb_str with contents of bfactors array.
Args:
Args:
pdb_str: An input PDB string.
pdb_str: An input PDB string.
bfactors: A numpy array with shape [1, n_residues, 37]. We assume that the
bfactors: A numpy array with shape [1, n_residues, 37]. We assume that the
B-factors are per residue; i.e. that the nonzero entries are identical in
B-factors are per residue; i.e. that the nonzero entries are identical in
[0, i, :].
[0, i, :].
Returns:
Returns:
A new PDB string with the B-factors replaced.
A new PDB string with the B-factors replaced.
"""
"""
if
bfactors
.
shape
[
-
1
]
!=
residue_constants
.
atom_type_num
:
if
bfactors
.
shape
[
-
1
]
!=
residue_constants
.
atom_type_num
:
raise
ValueError
(
raise
ValueError
(
f
'Invalid final dimension size for bfactors:
{
bfactors
.
shape
[
-
1
]
}
.'
)
f
"Invalid final dimension size for bfactors:
{
bfactors
.
shape
[
-
1
]
}
."
)
parser
=
PDB
.
PDBParser
(
QUIET
=
True
)
parser
=
PDB
.
PDBParser
(
QUIET
=
True
)
handle
=
io
.
StringIO
(
pdb_str
)
handle
=
io
.
StringIO
(
pdb_str
)
structure
=
parser
.
get_structure
(
''
,
handle
)
structure
=
parser
.
get_structure
(
""
,
handle
)
curr_resid
=
(
''
,
''
,
''
)
curr_resid
=
(
""
,
""
,
""
)
idx
=
-
1
idx
=
-
1
for
atom
in
structure
.
get_atoms
():
for
atom
in
structure
.
get_atoms
():
atom_resid
=
atom
.
parent
.
get_id
()
atom_resid
=
atom
.
parent
.
get_id
()
if
atom_resid
!=
curr_resid
:
if
atom_resid
!=
curr_resid
:
idx
+=
1
idx
+=
1
if
idx
>=
bfactors
.
shape
[
0
]:
if
idx
>=
bfactors
.
shape
[
0
]:
raise
ValueError
(
'Index into bfactors exceeds number of residues. '
raise
ValueError
(
'B-factors shape: {shape}, idx: {idx}.'
)
"Index into bfactors exceeds number of residues. "
curr_resid
=
atom_resid
"B-factors shape: {shape}, idx: {idx}."
atom
.
bfactor
=
bfactors
[
idx
,
residue_constants
.
atom_order
[
'CA'
]]
)
curr_resid
=
atom_resid
atom
.
bfactor
=
bfactors
[
idx
,
residue_constants
.
atom_order
[
"CA"
]]
new_pdb
=
io
.
StringIO
()
new_pdb
=
io
.
StringIO
()
pdb_io
=
PDB
.
PDBIO
()
pdb_io
=
PDB
.
PDBIO
()
pdb_io
.
set_structure
(
structure
)
pdb_io
.
set_structure
(
structure
)
pdb_io
.
save
(
new_pdb
)
pdb_io
.
save
(
new_pdb
)
return
new_pdb
.
getvalue
()
return
new_pdb
.
getvalue
()
def
assert_equal_nonterminal_atom_types
(
def
assert_equal_nonterminal_atom_types
(
atom_mask
:
np
.
ndarray
,
ref_atom_mask
:
np
.
ndarray
):
atom_mask
:
np
.
ndarray
,
ref_atom_mask
:
np
.
ndarray
"""Checks that pre- and post-minimized proteins have same atom set."""
):
# Ignore any terminal OXT atoms which may have been added by minimization.
"""Checks that pre- and post-minimized proteins have same atom set."""
oxt
=
residue_constants
.
atom_order
[
'OXT'
]
# Ignore any terminal OXT atoms which may have been added by minimization.
no_oxt_mask
=
np
.
ones
(
shape
=
atom_mask
.
shape
,
dtype
=
np
.
bool
)
oxt
=
residue_constants
.
atom_order
[
"OXT"
]
no_oxt_mask
[...,
oxt
]
=
False
no_oxt_mask
=
np
.
ones
(
shape
=
atom_mask
.
shape
,
dtype
=
np
.
bool
)
np
.
testing
.
assert_almost_equal
(
ref_atom_mask
[
no_oxt_mask
],
no_oxt_mask
[...,
oxt
]
=
False
atom_mask
[
no_oxt_mask
])
np
.
testing
.
assert_almost_equal
(
ref_atom_mask
[
no_oxt_mask
],
atom_mask
[
no_oxt_mask
]
)
openfold/np/residue_constants.py
View file @
07e64267
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
# Copyright 2021 DeepMind Technologies Limited
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
...
@@ -32,32 +32,49 @@ ca_ca = 3.80209737096
...
@@ -32,32 +32,49 @@ ca_ca = 3.80209737096
# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have
# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have
# chi angles so their chi angle lists are empty.
# chi angles so their chi angle lists are empty.
chi_angles_atoms
=
{
chi_angles_atoms
=
{
'
ALA
'
:
[],
"
ALA
"
:
[],
# Chi5 in arginine is always 0 +- 5 degrees, so ignore it.
# Chi5 in arginine is always 0 +- 5 degrees, so ignore it.
'ARG'
:
[[
'N'
,
'CA'
,
'CB'
,
'CG'
],
[
'CA'
,
'CB'
,
'CG'
,
'CD'
],
"ARG"
:
[
[
'CB'
,
'CG'
,
'CD'
,
'NE'
],
[
'CG'
,
'CD'
,
'NE'
,
'CZ'
]],
[
"N"
,
"CA"
,
"CB"
,
"CG"
],
'ASN'
:
[[
'N'
,
'CA'
,
'CB'
,
'CG'
],
[
'CA'
,
'CB'
,
'CG'
,
'OD1'
]],
[
"CA"
,
"CB"
,
"CG"
,
"CD"
],
'ASP'
:
[[
'N'
,
'CA'
,
'CB'
,
'CG'
],
[
'CA'
,
'CB'
,
'CG'
,
'OD1'
]],
[
"CB"
,
"CG"
,
"CD"
,
"NE"
],
'CYS'
:
[[
'N'
,
'CA'
,
'CB'
,
'SG'
]],
[
"CG"
,
"CD"
,
"NE"
,
"CZ"
],
'GLN'
:
[[
'N'
,
'CA'
,
'CB'
,
'CG'
],
[
'CA'
,
'CB'
,
'CG'
,
'CD'
],
],
[
'CB'
,
'CG'
,
'CD'
,
'OE1'
]],
"ASN"
:
[[
"N"
,
"CA"
,
"CB"
,
"CG"
],
[
"CA"
,
"CB"
,
"CG"
,
"OD1"
]],
'GLU'
:
[[
'N'
,
'CA'
,
'CB'
,
'CG'
],
[
'CA'
,
'CB'
,
'CG'
,
'CD'
],
"ASP"
:
[[
"N"
,
"CA"
,
"CB"
,
"CG"
],
[
"CA"
,
"CB"
,
"CG"
,
"OD1"
]],
[
'CB'
,
'CG'
,
'CD'
,
'OE1'
]],
"CYS"
:
[[
"N"
,
"CA"
,
"CB"
,
"SG"
]],
'GLY'
:
[],
"GLN"
:
[
'HIS'
:
[[
'N'
,
'CA'
,
'CB'
,
'CG'
],
[
'CA'
,
'CB'
,
'CG'
,
'ND1'
]],
[
"N"
,
"CA"
,
"CB"
,
"CG"
],
'ILE'
:
[[
'N'
,
'CA'
,
'CB'
,
'CG1'
],
[
'CA'
,
'CB'
,
'CG1'
,
'CD1'
]],
[
"CA"
,
"CB"
,
"CG"
,
"CD"
],
'LEU'
:
[[
'N'
,
'CA'
,
'CB'
,
'CG'
],
[
'CA'
,
'CB'
,
'CG'
,
'CD1'
]],
[
"CB"
,
"CG"
,
"CD"
,
"OE1"
],
'LYS'
:
[[
'N'
,
'CA'
,
'CB'
,
'CG'
],
[
'CA'
,
'CB'
,
'CG'
,
'CD'
],
],
[
'CB'
,
'CG'
,
'CD'
,
'CE'
],
[
'CG'
,
'CD'
,
'CE'
,
'NZ'
]],
"GLU"
:
[
'MET'
:
[[
'N'
,
'CA'
,
'CB'
,
'CG'
],
[
'CA'
,
'CB'
,
'CG'
,
'SD'
],
[
"N"
,
"CA"
,
"CB"
,
"CG"
],
[
'CB'
,
'CG'
,
'SD'
,
'CE'
]],
[
"CA"
,
"CB"
,
"CG"
,
"CD"
],
'PHE'
:
[[
'N'
,
'CA'
,
'CB'
,
'CG'
],
[
'CA'
,
'CB'
,
'CG'
,
'CD1'
]],
[
"CB"
,
"CG"
,
"CD"
,
"OE1"
],
'PRO'
:
[[
'N'
,
'CA'
,
'CB'
,
'CG'
],
[
'CA'
,
'CB'
,
'CG'
,
'CD'
]],
],
'SER'
:
[[
'N'
,
'CA'
,
'CB'
,
'OG'
]],
"GLY"
:
[],
'THR'
:
[[
'N'
,
'CA'
,
'CB'
,
'OG1'
]],
"HIS"
:
[[
"N"
,
"CA"
,
"CB"
,
"CG"
],
[
"CA"
,
"CB"
,
"CG"
,
"ND1"
]],
'TRP'
:
[[
'N'
,
'CA'
,
'CB'
,
'CG'
],
[
'CA'
,
'CB'
,
'CG'
,
'CD1'
]],
"ILE"
:
[[
"N"
,
"CA"
,
"CB"
,
"CG1"
],
[
"CA"
,
"CB"
,
"CG1"
,
"CD1"
]],
'TYR'
:
[[
'N'
,
'CA'
,
'CB'
,
'CG'
],
[
'CA'
,
'CB'
,
'CG'
,
'CD1'
]],
"LEU"
:
[[
"N"
,
"CA"
,
"CB"
,
"CG"
],
[
"CA"
,
"CB"
,
"CG"
,
"CD1"
]],
'VAL'
:
[[
'N'
,
'CA'
,
'CB'
,
'CG1'
]],
"LYS"
:
[
[
"N"
,
"CA"
,
"CB"
,
"CG"
],
[
"CA"
,
"CB"
,
"CG"
,
"CD"
],
[
"CB"
,
"CG"
,
"CD"
,
"CE"
],
[
"CG"
,
"CD"
,
"CE"
,
"NZ"
],
],
"MET"
:
[
[
"N"
,
"CA"
,
"CB"
,
"CG"
],
[
"CA"
,
"CB"
,
"CG"
,
"SD"
],
[
"CB"
,
"CG"
,
"SD"
,
"CE"
],
],
"PHE"
:
[[
"N"
,
"CA"
,
"CB"
,
"CG"
],
[
"CA"
,
"CB"
,
"CG"
,
"CD1"
]],
"PRO"
:
[[
"N"
,
"CA"
,
"CB"
,
"CG"
],
[
"CA"
,
"CB"
,
"CG"
,
"CD"
]],
"SER"
:
[[
"N"
,
"CA"
,
"CB"
,
"OG"
]],
"THR"
:
[[
"N"
,
"CA"
,
"CB"
,
"OG1"
]],
"TRP"
:
[[
"N"
,
"CA"
,
"CB"
,
"CG"
],
[
"CA"
,
"CB"
,
"CG"
,
"CD1"
]],
"TYR"
:
[[
"N"
,
"CA"
,
"CB"
,
"CG"
],
[
"CA"
,
"CB"
,
"CG"
,
"CD1"
]],
"VAL"
:
[[
"N"
,
"CA"
,
"CB"
,
"CG1"
]],
}
}
# If chi angles given in fixed-length array, this matrix determines how to mask
# If chi angles given in fixed-length array, this matrix determines how to mask
...
@@ -124,240 +141,266 @@ chi_pi_periodic = [
...
@@ -124,240 +141,266 @@ chi_pi_periodic = [
# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate).
# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate).
# format: [atomname, group_idx, rel_position]
# format: [atomname, group_idx, rel_position]
rigid_group_atom_positions
=
{
rigid_group_atom_positions
=
{
'
ALA
'
:
[
"
ALA
"
:
[
[
'N'
,
0
,
(
-
0.525
,
1.363
,
0.000
)],
[
"N"
,
0
,
(
-
0.525
,
1.363
,
0.000
)],
[
'
CA
'
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
"
CA
"
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
'C'
,
0
,
(
1.526
,
-
0.000
,
-
0.000
)],
[
"C"
,
0
,
(
1.526
,
-
0.000
,
-
0.000
)],
[
'
CB
'
,
0
,
(
-
0.529
,
-
0.774
,
-
1.205
)],
[
"
CB
"
,
0
,
(
-
0.529
,
-
0.774
,
-
1.205
)],
[
'O'
,
3
,
(
0.627
,
1.062
,
0.000
)],
[
"O"
,
3
,
(
0.627
,
1.062
,
0.000
)],
],
],
'
ARG
'
:
[
"
ARG
"
:
[
[
'N'
,
0
,
(
-
0.524
,
1.362
,
-
0.000
)],
[
"N"
,
0
,
(
-
0.524
,
1.362
,
-
0.000
)],
[
'
CA
'
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
"
CA
"
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
'C'
,
0
,
(
1.525
,
-
0.000
,
-
0.000
)],
[
"C"
,
0
,
(
1.525
,
-
0.000
,
-
0.000
)],
[
'
CB
'
,
0
,
(
-
0.524
,
-
0.778
,
-
1.209
)],
[
"
CB
"
,
0
,
(
-
0.524
,
-
0.778
,
-
1.209
)],
[
'O'
,
3
,
(
0.626
,
1.062
,
0.000
)],
[
"O"
,
3
,
(
0.626
,
1.062
,
0.000
)],
[
'
CG
'
,
4
,
(
0.616
,
1.390
,
-
0.000
)],
[
"
CG
"
,
4
,
(
0.616
,
1.390
,
-
0.000
)],
[
'
CD
'
,
5
,
(
0.564
,
1.414
,
0.000
)],
[
"
CD
"
,
5
,
(
0.564
,
1.414
,
0.000
)],
[
'
NE
'
,
6
,
(
0.539
,
1.357
,
-
0.000
)],
[
"
NE
"
,
6
,
(
0.539
,
1.357
,
-
0.000
)],
[
'
NH1
'
,
7
,
(
0.206
,
2.301
,
0.000
)],
[
"
NH1
"
,
7
,
(
0.206
,
2.301
,
0.000
)],
[
'
NH2
'
,
7
,
(
2.078
,
0.978
,
-
0.000
)],
[
"
NH2
"
,
7
,
(
2.078
,
0.978
,
-
0.000
)],
[
'
CZ
'
,
7
,
(
0.758
,
1.093
,
-
0.000
)],
[
"
CZ
"
,
7
,
(
0.758
,
1.093
,
-
0.000
)],
],
],
'
ASN
'
:
[
"
ASN
"
:
[
[
'N'
,
0
,
(
-
0.536
,
1.357
,
0.000
)],
[
"N"
,
0
,
(
-
0.536
,
1.357
,
0.000
)],
[
'
CA
'
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
"
CA
"
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
'C'
,
0
,
(
1.526
,
-
0.000
,
-
0.000
)],
[
"C"
,
0
,
(
1.526
,
-
0.000
,
-
0.000
)],
[
'
CB
'
,
0
,
(
-
0.531
,
-
0.787
,
-
1.200
)],
[
"
CB
"
,
0
,
(
-
0.531
,
-
0.787
,
-
1.200
)],
[
'O'
,
3
,
(
0.625
,
1.062
,
0.000
)],
[
"O"
,
3
,
(
0.625
,
1.062
,
0.000
)],
[
'
CG
'
,
4
,
(
0.584
,
1.399
,
0.000
)],
[
"
CG
"
,
4
,
(
0.584
,
1.399
,
0.000
)],
[
'
ND2
'
,
5
,
(
0.593
,
-
1.188
,
0.001
)],
[
"
ND2
"
,
5
,
(
0.593
,
-
1.188
,
0.001
)],
[
'
OD1
'
,
5
,
(
0.633
,
1.059
,
0.000
)],
[
"
OD1
"
,
5
,
(
0.633
,
1.059
,
0.000
)],
],
],
'
ASP
'
:
[
"
ASP
"
:
[
[
'N'
,
0
,
(
-
0.525
,
1.362
,
-
0.000
)],
[
"N"
,
0
,
(
-
0.525
,
1.362
,
-
0.000
)],
[
'
CA
'
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
"
CA
"
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
'C'
,
0
,
(
1.527
,
0.000
,
-
0.000
)],
[
"C"
,
0
,
(
1.527
,
0.000
,
-
0.000
)],
[
'
CB
'
,
0
,
(
-
0.526
,
-
0.778
,
-
1.208
)],
[
"
CB
"
,
0
,
(
-
0.526
,
-
0.778
,
-
1.208
)],
[
'O'
,
3
,
(
0.626
,
1.062
,
-
0.000
)],
[
"O"
,
3
,
(
0.626
,
1.062
,
-
0.000
)],
[
'
CG
'
,
4
,
(
0.593
,
1.398
,
-
0.000
)],
[
"
CG
"
,
4
,
(
0.593
,
1.398
,
-
0.000
)],
[
'
OD1
'
,
5
,
(
0.610
,
1.091
,
0.000
)],
[
"
OD1
"
,
5
,
(
0.610
,
1.091
,
0.000
)],
[
'
OD2
'
,
5
,
(
0.592
,
-
1.101
,
-
0.003
)],
[
"
OD2
"
,
5
,
(
0.592
,
-
1.101
,
-
0.003
)],
],
],
'
CYS
'
:
[
"
CYS
"
:
[
[
'N'
,
0
,
(
-
0.522
,
1.362
,
-
0.000
)],
[
"N"
,
0
,
(
-
0.522
,
1.362
,
-
0.000
)],
[
'
CA
'
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
"
CA
"
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
'C'
,
0
,
(
1.524
,
0.000
,
0.000
)],
[
"C"
,
0
,
(
1.524
,
0.000
,
0.000
)],
[
'
CB
'
,
0
,
(
-
0.519
,
-
0.773
,
-
1.212
)],
[
"
CB
"
,
0
,
(
-
0.519
,
-
0.773
,
-
1.212
)],
[
'O'
,
3
,
(
0.625
,
1.062
,
-
0.000
)],
[
"O"
,
3
,
(
0.625
,
1.062
,
-
0.000
)],
[
'
SG
'
,
4
,
(
0.728
,
1.653
,
0.000
)],
[
"
SG
"
,
4
,
(
0.728
,
1.653
,
0.000
)],
],
],
'
GLN
'
:
[
"
GLN
"
:
[
[
'N'
,
0
,
(
-
0.526
,
1.361
,
-
0.000
)],
[
"N"
,
0
,
(
-
0.526
,
1.361
,
-
0.000
)],
[
'
CA
'
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
"
CA
"
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
'C'
,
0
,
(
1.526
,
0.000
,
0.000
)],
[
"C"
,
0
,
(
1.526
,
0.000
,
0.000
)],
[
'
CB
'
,
0
,
(
-
0.525
,
-
0.779
,
-
1.207
)],
[
"
CB
"
,
0
,
(
-
0.525
,
-
0.779
,
-
1.207
)],
[
'O'
,
3
,
(
0.626
,
1.062
,
-
0.000
)],
[
"O"
,
3
,
(
0.626
,
1.062
,
-
0.000
)],
[
'
CG
'
,
4
,
(
0.615
,
1.393
,
0.000
)],
[
"
CG
"
,
4
,
(
0.615
,
1.393
,
0.000
)],
[
'
CD
'
,
5
,
(
0.587
,
1.399
,
-
0.000
)],
[
"
CD
"
,
5
,
(
0.587
,
1.399
,
-
0.000
)],
[
'
NE2
'
,
6
,
(
0.593
,
-
1.189
,
-
0.001
)],
[
"
NE2
"
,
6
,
(
0.593
,
-
1.189
,
-
0.001
)],
[
'
OE1
'
,
6
,
(
0.634
,
1.060
,
0.000
)],
[
"
OE1
"
,
6
,
(
0.634
,
1.060
,
0.000
)],
],
],
'
GLU
'
:
[
"
GLU
"
:
[
[
'N'
,
0
,
(
-
0.528
,
1.361
,
0.000
)],
[
"N"
,
0
,
(
-
0.528
,
1.361
,
0.000
)],
[
'
CA
'
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
"
CA
"
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
'C'
,
0
,
(
1.526
,
-
0.000
,
-
0.000
)],
[
"C"
,
0
,
(
1.526
,
-
0.000
,
-
0.000
)],
[
'
CB
'
,
0
,
(
-
0.526
,
-
0.781
,
-
1.207
)],
[
"
CB
"
,
0
,
(
-
0.526
,
-
0.781
,
-
1.207
)],
[
'O'
,
3
,
(
0.626
,
1.062
,
0.000
)],
[
"O"
,
3
,
(
0.626
,
1.062
,
0.000
)],
[
'
CG
'
,
4
,
(
0.615
,
1.392
,
0.000
)],
[
"
CG
"
,
4
,
(
0.615
,
1.392
,
0.000
)],
[
'
CD
'
,
5
,
(
0.600
,
1.397
,
0.000
)],
[
"
CD
"
,
5
,
(
0.600
,
1.397
,
0.000
)],
[
'
OE1
'
,
6
,
(
0.607
,
1.095
,
-
0.000
)],
[
"
OE1
"
,
6
,
(
0.607
,
1.095
,
-
0.000
)],
[
'
OE2
'
,
6
,
(
0.589
,
-
1.104
,
-
0.001
)],
[
"
OE2
"
,
6
,
(
0.589
,
-
1.104
,
-
0.001
)],
],
],
'
GLY
'
:
[
"
GLY
"
:
[
[
'N'
,
0
,
(
-
0.572
,
1.337
,
0.000
)],
[
"N"
,
0
,
(
-
0.572
,
1.337
,
0.000
)],
[
'
CA
'
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
"
CA
"
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
'C'
,
0
,
(
1.517
,
-
0.000
,
-
0.000
)],
[
"C"
,
0
,
(
1.517
,
-
0.000
,
-
0.000
)],
[
'O'
,
3
,
(
0.626
,
1.062
,
-
0.000
)],
[
"O"
,
3
,
(
0.626
,
1.062
,
-
0.000
)],
],
],
'
HIS
'
:
[
"
HIS
"
:
[
[
'N'
,
0
,
(
-
0.527
,
1.360
,
0.000
)],
[
"N"
,
0
,
(
-
0.527
,
1.360
,
0.000
)],
[
'
CA
'
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
"
CA
"
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
'C'
,
0
,
(
1.525
,
0.000
,
0.000
)],
[
"C"
,
0
,
(
1.525
,
0.000
,
0.000
)],
[
'
CB
'
,
0
,
(
-
0.525
,
-
0.778
,
-
1.208
)],
[
"
CB
"
,
0
,
(
-
0.525
,
-
0.778
,
-
1.208
)],
[
'O'
,
3
,
(
0.625
,
1.063
,
0.000
)],
[
"O"
,
3
,
(
0.625
,
1.063
,
0.000
)],
[
'
CG
'
,
4
,
(
0.600
,
1.370
,
-
0.000
)],
[
"
CG
"
,
4
,
(
0.600
,
1.370
,
-
0.000
)],
[
'
CD2
'
,
5
,
(
0.889
,
-
1.021
,
0.003
)],
[
"
CD2
"
,
5
,
(
0.889
,
-
1.021
,
0.003
)],
[
'
ND1
'
,
5
,
(
0.744
,
1.160
,
-
0.000
)],
[
"
ND1
"
,
5
,
(
0.744
,
1.160
,
-
0.000
)],
[
'
CE1
'
,
5
,
(
2.030
,
0.851
,
0.002
)],
[
"
CE1
"
,
5
,
(
2.030
,
0.851
,
0.002
)],
[
'
NE2
'
,
5
,
(
2.145
,
-
0.466
,
0.004
)],
[
"
NE2
"
,
5
,
(
2.145
,
-
0.466
,
0.004
)],
],
],
'
ILE
'
:
[
"
ILE
"
:
[
[
'N'
,
0
,
(
-
0.493
,
1.373
,
-
0.000
)],
[
"N"
,
0
,
(
-
0.493
,
1.373
,
-
0.000
)],
[
'
CA
'
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
"
CA
"
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
'C'
,
0
,
(
1.527
,
-
0.000
,
-
0.000
)],
[
"C"
,
0
,
(
1.527
,
-
0.000
,
-
0.000
)],
[
'
CB
'
,
0
,
(
-
0.536
,
-
0.793
,
-
1.213
)],
[
"
CB
"
,
0
,
(
-
0.536
,
-
0.793
,
-
1.213
)],
[
'O'
,
3
,
(
0.627
,
1.062
,
-
0.000
)],
[
"O"
,
3
,
(
0.627
,
1.062
,
-
0.000
)],
[
'
CG1
'
,
4
,
(
0.534
,
1.437
,
-
0.000
)],
[
"
CG1
"
,
4
,
(
0.534
,
1.437
,
-
0.000
)],
[
'
CG2
'
,
4
,
(
0.540
,
-
0.785
,
-
1.199
)],
[
"
CG2
"
,
4
,
(
0.540
,
-
0.785
,
-
1.199
)],
[
'
CD1
'
,
5
,
(
0.619
,
1.391
,
0.000
)],
[
"
CD1
"
,
5
,
(
0.619
,
1.391
,
0.000
)],
],
],
'
LEU
'
:
[
"
LEU
"
:
[
[
'N'
,
0
,
(
-
0.520
,
1.363
,
0.000
)],
[
"N"
,
0
,
(
-
0.520
,
1.363
,
0.000
)],
[
'
CA
'
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
"
CA
"
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
'C'
,
0
,
(
1.525
,
-
0.000
,
-
0.000
)],
[
"C"
,
0
,
(
1.525
,
-
0.000
,
-
0.000
)],
[
'
CB
'
,
0
,
(
-
0.522
,
-
0.773
,
-
1.214
)],
[
"
CB
"
,
0
,
(
-
0.522
,
-
0.773
,
-
1.214
)],
[
'O'
,
3
,
(
0.625
,
1.063
,
-
0.000
)],
[
"O"
,
3
,
(
0.625
,
1.063
,
-
0.000
)],
[
'
CG
'
,
4
,
(
0.678
,
1.371
,
0.000
)],
[
"
CG
"
,
4
,
(
0.678
,
1.371
,
0.000
)],
[
'
CD1
'
,
5
,
(
0.530
,
1.430
,
-
0.000
)],
[
"
CD1
"
,
5
,
(
0.530
,
1.430
,
-
0.000
)],
[
'
CD2
'
,
5
,
(
0.535
,
-
0.774
,
1.200
)],
[
"
CD2
"
,
5
,
(
0.535
,
-
0.774
,
1.200
)],
],
],
'
LYS
'
:
[
"
LYS
"
:
[
[
'N'
,
0
,
(
-
0.526
,
1.362
,
-
0.000
)],
[
"N"
,
0
,
(
-
0.526
,
1.362
,
-
0.000
)],
[
'
CA
'
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
"
CA
"
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
'C'
,
0
,
(
1.526
,
0.000
,
0.000
)],
[
"C"
,
0
,
(
1.526
,
0.000
,
0.000
)],
[
'
CB
'
,
0
,
(
-
0.524
,
-
0.778
,
-
1.208
)],
[
"
CB
"
,
0
,
(
-
0.524
,
-
0.778
,
-
1.208
)],
[
'O'
,
3
,
(
0.626
,
1.062
,
-
0.000
)],
[
"O"
,
3
,
(
0.626
,
1.062
,
-
0.000
)],
[
'
CG
'
,
4
,
(
0.619
,
1.390
,
0.000
)],
[
"
CG
"
,
4
,
(
0.619
,
1.390
,
0.000
)],
[
'
CD
'
,
5
,
(
0.559
,
1.417
,
0.000
)],
[
"
CD
"
,
5
,
(
0.559
,
1.417
,
0.000
)],
[
'
CE
'
,
6
,
(
0.560
,
1.416
,
0.000
)],
[
"
CE
"
,
6
,
(
0.560
,
1.416
,
0.000
)],
[
'
NZ
'
,
7
,
(
0.554
,
1.387
,
0.000
)],
[
"
NZ
"
,
7
,
(
0.554
,
1.387
,
0.000
)],
],
],
'
MET
'
:
[
"
MET
"
:
[
[
'N'
,
0
,
(
-
0.521
,
1.364
,
-
0.000
)],
[
"N"
,
0
,
(
-
0.521
,
1.364
,
-
0.000
)],
[
'
CA
'
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
"
CA
"
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
'C'
,
0
,
(
1.525
,
0.000
,
0.000
)],
[
"C"
,
0
,
(
1.525
,
0.000
,
0.000
)],
[
'
CB
'
,
0
,
(
-
0.523
,
-
0.776
,
-
1.210
)],
[
"
CB
"
,
0
,
(
-
0.523
,
-
0.776
,
-
1.210
)],
[
'O'
,
3
,
(
0.625
,
1.062
,
-
0.000
)],
[
"O"
,
3
,
(
0.625
,
1.062
,
-
0.000
)],
[
'
CG
'
,
4
,
(
0.613
,
1.391
,
-
0.000
)],
[
"
CG
"
,
4
,
(
0.613
,
1.391
,
-
0.000
)],
[
'
SD
'
,
5
,
(
0.703
,
1.695
,
0.000
)],
[
"
SD
"
,
5
,
(
0.703
,
1.695
,
0.000
)],
[
'
CE
'
,
6
,
(
0.320
,
1.786
,
-
0.000
)],
[
"
CE
"
,
6
,
(
0.320
,
1.786
,
-
0.000
)],
],
],
'
PHE
'
:
[
"
PHE
"
:
[
[
'N'
,
0
,
(
-
0.518
,
1.363
,
0.000
)],
[
"N"
,
0
,
(
-
0.518
,
1.363
,
0.000
)],
[
'
CA
'
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
"
CA
"
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
'C'
,
0
,
(
1.524
,
0.000
,
-
0.000
)],
[
"C"
,
0
,
(
1.524
,
0.000
,
-
0.000
)],
[
'
CB
'
,
0
,
(
-
0.525
,
-
0.776
,
-
1.212
)],
[
"
CB
"
,
0
,
(
-
0.525
,
-
0.776
,
-
1.212
)],
[
'O'
,
3
,
(
0.626
,
1.062
,
-
0.000
)],
[
"O"
,
3
,
(
0.626
,
1.062
,
-
0.000
)],
[
'
CG
'
,
4
,
(
0.607
,
1.377
,
0.000
)],
[
"
CG
"
,
4
,
(
0.607
,
1.377
,
0.000
)],
[
'
CD1
'
,
5
,
(
0.709
,
1.195
,
-
0.000
)],
[
"
CD1
"
,
5
,
(
0.709
,
1.195
,
-
0.000
)],
[
'
CD2
'
,
5
,
(
0.706
,
-
1.196
,
0.000
)],
[
"
CD2
"
,
5
,
(
0.706
,
-
1.196
,
0.000
)],
[
'
CE1
'
,
5
,
(
2.102
,
1.198
,
-
0.000
)],
[
"
CE1
"
,
5
,
(
2.102
,
1.198
,
-
0.000
)],
[
'
CE2
'
,
5
,
(
2.098
,
-
1.201
,
-
0.000
)],
[
"
CE2
"
,
5
,
(
2.098
,
-
1.201
,
-
0.000
)],
[
'
CZ
'
,
5
,
(
2.794
,
-
0.003
,
-
0.001
)],
[
"
CZ
"
,
5
,
(
2.794
,
-
0.003
,
-
0.001
)],
],
],
'
PRO
'
:
[
"
PRO
"
:
[
[
'N'
,
0
,
(
-
0.566
,
1.351
,
-
0.000
)],
[
"N"
,
0
,
(
-
0.566
,
1.351
,
-
0.000
)],
[
'
CA
'
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
"
CA
"
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
'C'
,
0
,
(
1.527
,
-
0.000
,
0.000
)],
[
"C"
,
0
,
(
1.527
,
-
0.000
,
0.000
)],
[
'
CB
'
,
0
,
(
-
0.546
,
-
0.611
,
-
1.293
)],
[
"
CB
"
,
0
,
(
-
0.546
,
-
0.611
,
-
1.293
)],
[
'O'
,
3
,
(
0.621
,
1.066
,
0.000
)],
[
"O"
,
3
,
(
0.621
,
1.066
,
0.000
)],
[
'
CG
'
,
4
,
(
0.382
,
1.445
,
0.0
)],
[
"
CG
"
,
4
,
(
0.382
,
1.445
,
0.0
)],
# ['CD', 5, (0.427, 1.440, 0.0)],
# ['CD', 5, (0.427, 1.440, 0.0)],
[
'
CD
'
,
5
,
(
0.477
,
1.424
,
0.0
)],
# manually made angle 2 degrees larger
[
"
CD
"
,
5
,
(
0.477
,
1.424
,
0.0
)],
# manually made angle 2 degrees larger
],
],
'
SER
'
:
[
"
SER
"
:
[
[
'N'
,
0
,
(
-
0.529
,
1.360
,
-
0.000
)],
[
"N"
,
0
,
(
-
0.529
,
1.360
,
-
0.000
)],
[
'
CA
'
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
"
CA
"
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
'C'
,
0
,
(
1.525
,
-
0.000
,
-
0.000
)],
[
"C"
,
0
,
(
1.525
,
-
0.000
,
-
0.000
)],
[
'
CB
'
,
0
,
(
-
0.518
,
-
0.777
,
-
1.211
)],
[
"
CB
"
,
0
,
(
-
0.518
,
-
0.777
,
-
1.211
)],
[
'O'
,
3
,
(
0.626
,
1.062
,
-
0.000
)],
[
"O"
,
3
,
(
0.626
,
1.062
,
-
0.000
)],
[
'
OG
'
,
4
,
(
0.503
,
1.325
,
0.000
)],
[
"
OG
"
,
4
,
(
0.503
,
1.325
,
0.000
)],
],
],
'
THR
'
:
[
"
THR
"
:
[
[
'N'
,
0
,
(
-
0.517
,
1.364
,
0.000
)],
[
"N"
,
0
,
(
-
0.517
,
1.364
,
0.000
)],
[
'
CA
'
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
"
CA
"
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
'C'
,
0
,
(
1.526
,
0.000
,
-
0.000
)],
[
"C"
,
0
,
(
1.526
,
0.000
,
-
0.000
)],
[
'
CB
'
,
0
,
(
-
0.516
,
-
0.793
,
-
1.215
)],
[
"
CB
"
,
0
,
(
-
0.516
,
-
0.793
,
-
1.215
)],
[
'O'
,
3
,
(
0.626
,
1.062
,
0.000
)],
[
"O"
,
3
,
(
0.626
,
1.062
,
0.000
)],
[
'
CG2
'
,
4
,
(
0.550
,
-
0.718
,
-
1.228
)],
[
"
CG2
"
,
4
,
(
0.550
,
-
0.718
,
-
1.228
)],
[
'
OG1
'
,
4
,
(
0.472
,
1.353
,
0.000
)],
[
"
OG1
"
,
4
,
(
0.472
,
1.353
,
0.000
)],
],
],
'
TRP
'
:
[
"
TRP
"
:
[
[
'N'
,
0
,
(
-
0.521
,
1.363
,
0.000
)],
[
"N"
,
0
,
(
-
0.521
,
1.363
,
0.000
)],
[
'
CA
'
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
"
CA
"
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
'C'
,
0
,
(
1.525
,
-
0.000
,
0.000
)],
[
"C"
,
0
,
(
1.525
,
-
0.000
,
0.000
)],
[
'
CB
'
,
0
,
(
-
0.523
,
-
0.776
,
-
1.212
)],
[
"
CB
"
,
0
,
(
-
0.523
,
-
0.776
,
-
1.212
)],
[
'O'
,
3
,
(
0.627
,
1.062
,
0.000
)],
[
"O"
,
3
,
(
0.627
,
1.062
,
0.000
)],
[
'
CG
'
,
4
,
(
0.609
,
1.370
,
-
0.000
)],
[
"
CG
"
,
4
,
(
0.609
,
1.370
,
-
0.000
)],
[
'
CD1
'
,
5
,
(
0.824
,
1.091
,
0.000
)],
[
"
CD1
"
,
5
,
(
0.824
,
1.091
,
0.000
)],
[
'
CD2
'
,
5
,
(
0.854
,
-
1.148
,
-
0.005
)],
[
"
CD2
"
,
5
,
(
0.854
,
-
1.148
,
-
0.005
)],
[
'
CE2
'
,
5
,
(
2.186
,
-
0.678
,
-
0.007
)],
[
"
CE2
"
,
5
,
(
2.186
,
-
0.678
,
-
0.007
)],
[
'
CE3
'
,
5
,
(
0.622
,
-
2.530
,
-
0.007
)],
[
"
CE3
"
,
5
,
(
0.622
,
-
2.530
,
-
0.007
)],
[
'
NE1
'
,
5
,
(
2.140
,
0.690
,
-
0.004
)],
[
"
NE1
"
,
5
,
(
2.140
,
0.690
,
-
0.004
)],
[
'
CH2
'
,
5
,
(
3.028
,
-
2.890
,
-
0.013
)],
[
"
CH2
"
,
5
,
(
3.028
,
-
2.890
,
-
0.013
)],
[
'
CZ2
'
,
5
,
(
3.283
,
-
1.543
,
-
0.011
)],
[
"
CZ2
"
,
5
,
(
3.283
,
-
1.543
,
-
0.011
)],
[
'
CZ3
'
,
5
,
(
1.715
,
-
3.389
,
-
0.011
)],
[
"
CZ3
"
,
5
,
(
1.715
,
-
3.389
,
-
0.011
)],
],
],
'
TYR
'
:
[
"
TYR
"
:
[
[
'N'
,
0
,
(
-
0.522
,
1.362
,
0.000
)],
[
"N"
,
0
,
(
-
0.522
,
1.362
,
0.000
)],
[
'
CA
'
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
"
CA
"
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
'C'
,
0
,
(
1.524
,
-
0.000
,
-
0.000
)],
[
"C"
,
0
,
(
1.524
,
-
0.000
,
-
0.000
)],
[
'
CB
'
,
0
,
(
-
0.522
,
-
0.776
,
-
1.213
)],
[
"
CB
"
,
0
,
(
-
0.522
,
-
0.776
,
-
1.213
)],
[
'O'
,
3
,
(
0.627
,
1.062
,
-
0.000
)],
[
"O"
,
3
,
(
0.627
,
1.062
,
-
0.000
)],
[
'
CG
'
,
4
,
(
0.607
,
1.382
,
-
0.000
)],
[
"
CG
"
,
4
,
(
0.607
,
1.382
,
-
0.000
)],
[
'
CD1
'
,
5
,
(
0.716
,
1.195
,
-
0.000
)],
[
"
CD1
"
,
5
,
(
0.716
,
1.195
,
-
0.000
)],
[
'
CD2
'
,
5
,
(
0.713
,
-
1.194
,
-
0.001
)],
[
"
CD2
"
,
5
,
(
0.713
,
-
1.194
,
-
0.001
)],
[
'
CE1
'
,
5
,
(
2.107
,
1.200
,
-
0.002
)],
[
"
CE1
"
,
5
,
(
2.107
,
1.200
,
-
0.002
)],
[
'
CE2
'
,
5
,
(
2.104
,
-
1.201
,
-
0.003
)],
[
"
CE2
"
,
5
,
(
2.104
,
-
1.201
,
-
0.003
)],
[
'
OH
'
,
5
,
(
4.168
,
-
0.002
,
-
0.005
)],
[
"
OH
"
,
5
,
(
4.168
,
-
0.002
,
-
0.005
)],
[
'
CZ
'
,
5
,
(
2.791
,
-
0.001
,
-
0.003
)],
[
"
CZ
"
,
5
,
(
2.791
,
-
0.001
,
-
0.003
)],
],
],
'
VAL
'
:
[
"
VAL
"
:
[
[
'N'
,
0
,
(
-
0.494
,
1.373
,
-
0.000
)],
[
"N"
,
0
,
(
-
0.494
,
1.373
,
-
0.000
)],
[
'
CA
'
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
"
CA
"
,
0
,
(
0.000
,
0.000
,
0.000
)],
[
'C'
,
0
,
(
1.527
,
-
0.000
,
-
0.000
)],
[
"C"
,
0
,
(
1.527
,
-
0.000
,
-
0.000
)],
[
'
CB
'
,
0
,
(
-
0.533
,
-
0.795
,
-
1.213
)],
[
"
CB
"
,
0
,
(
-
0.533
,
-
0.795
,
-
1.213
)],
[
'O'
,
3
,
(
0.627
,
1.062
,
-
0.000
)],
[
"O"
,
3
,
(
0.627
,
1.062
,
-
0.000
)],
[
'
CG1
'
,
4
,
(
0.540
,
1.429
,
-
0.000
)],
[
"
CG1
"
,
4
,
(
0.540
,
1.429
,
-
0.000
)],
[
'
CG2
'
,
4
,
(
0.533
,
-
0.776
,
1.203
)],
[
"
CG2
"
,
4
,
(
0.533
,
-
0.776
,
1.203
)],
],
],
}
}
# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention.
# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention.
residue_atoms
=
{
residue_atoms
=
{
'ALA'
:
[
'C'
,
'CA'
,
'CB'
,
'N'
,
'O'
],
"ALA"
:
[
"C"
,
"CA"
,
"CB"
,
"N"
,
"O"
],
'ARG'
:
[
'C'
,
'CA'
,
'CB'
,
'CG'
,
'CD'
,
'CZ'
,
'N'
,
'NE'
,
'O'
,
'NH1'
,
'NH2'
],
"ARG"
:
[
"C"
,
"CA"
,
"CB"
,
"CG"
,
"CD"
,
"CZ"
,
"N"
,
"NE"
,
"O"
,
"NH1"
,
"NH2"
],
'ASP'
:
[
'C'
,
'CA'
,
'CB'
,
'CG'
,
'N'
,
'O'
,
'OD1'
,
'OD2'
],
"ASP"
:
[
"C"
,
"CA"
,
"CB"
,
"CG"
,
"N"
,
"O"
,
"OD1"
,
"OD2"
],
'ASN'
:
[
'C'
,
'CA'
,
'CB'
,
'CG'
,
'N'
,
'ND2'
,
'O'
,
'OD1'
],
"ASN"
:
[
"C"
,
"CA"
,
"CB"
,
"CG"
,
"N"
,
"ND2"
,
"O"
,
"OD1"
],
'CYS'
:
[
'C'
,
'CA'
,
'CB'
,
'N'
,
'O'
,
'SG'
],
"CYS"
:
[
"C"
,
"CA"
,
"CB"
,
"N"
,
"O"
,
"SG"
],
'GLU'
:
[
'C'
,
'CA'
,
'CB'
,
'CG'
,
'CD'
,
'N'
,
'O'
,
'OE1'
,
'OE2'
],
"GLU"
:
[
"C"
,
"CA"
,
"CB"
,
"CG"
,
"CD"
,
"N"
,
"O"
,
"OE1"
,
"OE2"
],
'GLN'
:
[
'C'
,
'CA'
,
'CB'
,
'CG'
,
'CD'
,
'N'
,
'NE2'
,
'O'
,
'OE1'
],
"GLN"
:
[
"C"
,
"CA"
,
"CB"
,
"CG"
,
"CD"
,
"N"
,
"NE2"
,
"O"
,
"OE1"
],
'GLY'
:
[
'C'
,
'CA'
,
'N'
,
'O'
],
"GLY"
:
[
"C"
,
"CA"
,
"N"
,
"O"
],
'HIS'
:
[
'C'
,
'CA'
,
'CB'
,
'CG'
,
'CD2'
,
'CE1'
,
'N'
,
'ND1'
,
'NE2'
,
'O'
],
"HIS"
:
[
"C"
,
"CA"
,
"CB"
,
"CG"
,
"CD2"
,
"CE1"
,
"N"
,
"ND1"
,
"NE2"
,
"O"
],
'ILE'
:
[
'C'
,
'CA'
,
'CB'
,
'CG1'
,
'CG2'
,
'CD1'
,
'N'
,
'O'
],
"ILE"
:
[
"C"
,
"CA"
,
"CB"
,
"CG1"
,
"CG2"
,
"CD1"
,
"N"
,
"O"
],
'LEU'
:
[
'C'
,
'CA'
,
'CB'
,
'CG'
,
'CD1'
,
'CD2'
,
'N'
,
'O'
],
"LEU"
:
[
"C"
,
"CA"
,
"CB"
,
"CG"
,
"CD1"
,
"CD2"
,
"N"
,
"O"
],
'LYS'
:
[
'C'
,
'CA'
,
'CB'
,
'CG'
,
'CD'
,
'CE'
,
'N'
,
'NZ'
,
'O'
],
"LYS"
:
[
"C"
,
"CA"
,
"CB"
,
"CG"
,
"CD"
,
"CE"
,
"N"
,
"NZ"
,
"O"
],
'MET'
:
[
'C'
,
'CA'
,
'CB'
,
'CG'
,
'CE'
,
'N'
,
'O'
,
'SD'
],
"MET"
:
[
"C"
,
"CA"
,
"CB"
,
"CG"
,
"CE"
,
"N"
,
"O"
,
"SD"
],
'PHE'
:
[
'C'
,
'CA'
,
'CB'
,
'CG'
,
'CD1'
,
'CD2'
,
'CE1'
,
'CE2'
,
'CZ'
,
'N'
,
'O'
],
"PHE"
:
[
"C"
,
"CA"
,
"CB"
,
"CG"
,
"CD1"
,
"CD2"
,
"CE1"
,
"CE2"
,
"CZ"
,
"N"
,
"O"
],
'PRO'
:
[
'C'
,
'CA'
,
'CB'
,
'CG'
,
'CD'
,
'N'
,
'O'
],
"PRO"
:
[
"C"
,
"CA"
,
"CB"
,
"CG"
,
"CD"
,
"N"
,
"O"
],
'SER'
:
[
'C'
,
'CA'
,
'CB'
,
'N'
,
'O'
,
'OG'
],
"SER"
:
[
"C"
,
"CA"
,
"CB"
,
"N"
,
"O"
,
"OG"
],
'THR'
:
[
'C'
,
'CA'
,
'CB'
,
'CG2'
,
'N'
,
'O'
,
'OG1'
],
"THR"
:
[
"C"
,
"CA"
,
"CB"
,
"CG2"
,
"N"
,
"O"
,
"OG1"
],
'TRP'
:
[
'C'
,
'CA'
,
'CB'
,
'CG'
,
'CD1'
,
'CD2'
,
'CE2'
,
'CE3'
,
'CZ2'
,
'CZ3'
,
"TRP"
:
[
'CH2'
,
'N'
,
'NE1'
,
'O'
],
"C"
,
'TYR'
:
[
'C'
,
'CA'
,
'CB'
,
'CG'
,
'CD1'
,
'CD2'
,
'CE1'
,
'CE2'
,
'CZ'
,
'N'
,
'O'
,
"CA"
,
'OH'
],
"CB"
,
'VAL'
:
[
'C'
,
'CA'
,
'CB'
,
'CG1'
,
'CG2'
,
'N'
,
'O'
]
"CG"
,
"CD1"
,
"CD2"
,
"CE2"
,
"CE3"
,
"CZ2"
,
"CZ3"
,
"CH2"
,
"N"
,
"NE1"
,
"O"
,
],
"TYR"
:
[
"C"
,
"CA"
,
"CB"
,
"CG"
,
"CD1"
,
"CD2"
,
"CE1"
,
"CE2"
,
"CZ"
,
"N"
,
"O"
,
"OH"
,
],
"VAL"
:
[
"C"
,
"CA"
,
"CB"
,
"CG1"
,
"CG2"
,
"N"
,
"O"
],
}
}
# Naming swaps for ambiguous atom names.
# Naming swaps for ambiguous atom names.
...
@@ -368,115 +411,134 @@ residue_atoms = {
...
@@ -368,115 +411,134 @@ residue_atoms = {
# the 'ambiguous' atoms and their neighbours)
# the 'ambiguous' atoms and their neighbours)
# TODO: ^ interpret this
# TODO: ^ interpret this
residue_atom_renaming_swaps
=
{
residue_atom_renaming_swaps
=
{
'
ASP
'
:
{
'
OD1
'
:
'
OD2
'
},
"
ASP
"
:
{
"
OD1
"
:
"
OD2
"
},
'
GLU
'
:
{
'
OE1
'
:
'
OE2
'
},
"
GLU
"
:
{
"
OE1
"
:
"
OE2
"
},
'
PHE
'
:
{
'
CD1
'
:
'
CD2
'
,
'
CE1
'
:
'
CE2
'
},
"
PHE
"
:
{
"
CD1
"
:
"
CD2
"
,
"
CE1
"
:
"
CE2
"
},
'
TYR
'
:
{
'
CD1
'
:
'
CD2
'
,
'
CE1
'
:
'
CE2
'
},
"
TYR
"
:
{
"
CD1
"
:
"
CD2
"
,
"
CE1
"
:
"
CE2
"
},
}
}
# Van der Waals radii [Angstroem] of the atoms (from Wikipedia)
# Van der Waals radii [Angstroem] of the atoms (from Wikipedia)
van_der_waals_radius
=
{
van_der_waals_radius
=
{
'C'
:
1.7
,
"C"
:
1.7
,
'N'
:
1.55
,
"N"
:
1.55
,
'O'
:
1.52
,
"O"
:
1.52
,
'S'
:
1.8
,
"S"
:
1.8
,
}
}
Bond
=
collections
.
namedtuple
(
Bond
=
collections
.
namedtuple
(
'Bond'
,
[
'atom1_name'
,
'atom2_name'
,
'length'
,
'stddev'
])
"Bond"
,
[
"atom1_name"
,
"atom2_name"
,
"length"
,
"stddev"
]
)
BondAngle
=
collections
.
namedtuple
(
BondAngle
=
collections
.
namedtuple
(
'BondAngle'
,
"BondAngle"
,
[
'atom1_name'
,
'atom2_name'
,
'atom3name'
,
'angle_rad'
,
'stddev'
])
[
"atom1_name"
,
"atom2_name"
,
"atom3name"
,
"angle_rad"
,
"stddev"
],
)
@
functools
.
lru_cache
(
maxsize
=
None
)
@
functools
.
lru_cache
(
maxsize
=
None
)
def
load_stereo_chemical_props
()
->
Tuple
[
Mapping
[
str
,
List
[
Bond
]],
def
load_stereo_chemical_props
()
->
Tuple
[
Mapping
[
str
,
List
[
Bond
]],
Mapping
[
str
,
List
[
Bond
]],
Mapping
[
str
,
List
[
BondAngle
]]]:
Mapping
[
str
,
List
[
Bond
]],
"""Load stereo_chemical_props.txt into a nice structure.
Mapping
[
str
,
List
[
BondAngle
]],
]:
Load literature values for bond lengths and bond angles and translate
"""Load stereo_chemical_props.txt into a nice structure.
bond angles into the length of the opposite edge of the triangle
("residue_virtual_bonds").
Load literature values for bond lengths and bond angles and translate
bond angles into the length of the opposite edge of the triangle
Returns:
("residue_virtual_bonds").
residue_bonds: dict that maps resname --> list of Bond tuples
residue_virtual_bonds: dict that maps resname --> list of Bond tuples
Returns:
residue_bond_angles: dict that maps resname --> list of BondAngle tuples
residue_bonds: dict that maps resname --> list of Bond tuples
"""
residue_virtual_bonds: dict that maps resname --> list of Bond tuples
# TODO: this file should be downloaded in a setup script
residue_bond_angles: dict that maps resname --> list of BondAngle tuples
stereo_chemical_props_path
=
(
"""
'openfold/resources/stereo_chemical_props.txt'
)
# TODO: this file should be downloaded in a setup script
with
open
(
stereo_chemical_props_path
,
'rt'
)
as
f
:
stereo_chemical_props_path
=
"openfold/resources/stereo_chemical_props.txt"
stereo_chemical_props
=
f
.
read
()
with
open
(
stereo_chemical_props_path
,
"rt"
)
as
f
:
lines_iter
=
iter
(
stereo_chemical_props
.
splitlines
())
stereo_chemical_props
=
f
.
read
()
# Load bond lengths.
lines_iter
=
iter
(
stereo_chemical_props
.
splitlines
())
residue_bonds
=
{}
# Load bond lengths.
next
(
lines_iter
)
# Skip header line.
residue_bonds
=
{}
for
line
in
lines_iter
:
next
(
lines_iter
)
# Skip header line.
if
line
.
strip
()
==
'-'
:
for
line
in
lines_iter
:
break
if
line
.
strip
()
==
"-"
:
bond
,
resname
,
length
,
stddev
=
line
.
split
()
break
atom1
,
atom2
=
bond
.
split
(
'-'
)
bond
,
resname
,
length
,
stddev
=
line
.
split
()
if
resname
not
in
residue_bonds
:
atom1
,
atom2
=
bond
.
split
(
"-"
)
residue_bonds
[
resname
]
=
[]
if
resname
not
in
residue_bonds
:
residue_bonds
[
resname
].
append
(
residue_bonds
[
resname
]
=
[]
Bond
(
atom1
,
atom2
,
float
(
length
),
float
(
stddev
)))
residue_bonds
[
resname
].
append
(
residue_bonds
[
'UNK'
]
=
[]
Bond
(
atom1
,
atom2
,
float
(
length
),
float
(
stddev
))
)
# Load bond angles.
residue_bonds
[
"UNK"
]
=
[]
residue_bond_angles
=
{}
next
(
lines_iter
)
# Skip empty line.
# Load bond angles.
next
(
lines_iter
)
# Skip header line.
residue_bond_angles
=
{}
for
line
in
lines_iter
:
next
(
lines_iter
)
# Skip empty line.
if
line
.
strip
()
==
'-'
:
next
(
lines_iter
)
# Skip header line.
break
for
line
in
lines_iter
:
bond
,
resname
,
angle_degree
,
stddev_degree
=
line
.
split
()
if
line
.
strip
()
==
"-"
:
atom1
,
atom2
,
atom3
=
bond
.
split
(
'-'
)
break
if
resname
not
in
residue_bond_angles
:
bond
,
resname
,
angle_degree
,
stddev_degree
=
line
.
split
()
residue_bond_angles
[
resname
]
=
[]
atom1
,
atom2
,
atom3
=
bond
.
split
(
"-"
)
residue_bond_angles
[
resname
].
append
(
if
resname
not
in
residue_bond_angles
:
BondAngle
(
atom1
,
atom2
,
atom3
,
residue_bond_angles
[
resname
]
=
[]
float
(
angle_degree
)
/
180.
*
np
.
pi
,
residue_bond_angles
[
resname
].
append
(
float
(
stddev_degree
)
/
180.
*
np
.
pi
))
BondAngle
(
residue_bond_angles
[
'UNK'
]
=
[]
atom1
,
atom2
,
def
make_bond_key
(
atom1_name
,
atom2_name
):
atom3
,
"""Unique key to lookup bonds."""
float
(
angle_degree
)
/
180.0
*
np
.
pi
,
return
'-'
.
join
(
sorted
([
atom1_name
,
atom2_name
]))
float
(
stddev_degree
)
/
180.0
*
np
.
pi
,
)
# Translate bond angles into distances ("virtual bonds").
)
residue_virtual_bonds
=
{}
residue_bond_angles
[
"UNK"
]
=
[]
for
resname
,
bond_angles
in
residue_bond_angles
.
items
():
# Create a fast lookup dict for bond lengths.
def
make_bond_key
(
atom1_name
,
atom2_name
):
bond_cache
=
{}
"""Unique key to lookup bonds."""
for
b
in
residue_bonds
[
resname
]:
return
"-"
.
join
(
sorted
([
atom1_name
,
atom2_name
]))
bond_cache
[
make_bond_key
(
b
.
atom1_name
,
b
.
atom2_name
)]
=
b
residue_virtual_bonds
[
resname
]
=
[]
# Translate bond angles into distances ("virtual bonds").
for
ba
in
bond_angles
:
residue_virtual_bonds
=
{}
bond1
=
bond_cache
[
make_bond_key
(
ba
.
atom1_name
,
ba
.
atom2_name
)]
for
resname
,
bond_angles
in
residue_bond_angles
.
items
():
bond2
=
bond_cache
[
make_bond_key
(
ba
.
atom2_name
,
ba
.
atom3name
)]
# Create a fast lookup dict for bond lengths.
bond_cache
=
{}
# Compute distance between atom1 and atom3 using the law of cosines
for
b
in
residue_bonds
[
resname
]:
# c^2 = a^2 + b^2 - 2ab*cos(gamma).
bond_cache
[
make_bond_key
(
b
.
atom1_name
,
b
.
atom2_name
)]
=
b
gamma
=
ba
.
angle_rad
residue_virtual_bonds
[
resname
]
=
[]
length
=
np
.
sqrt
(
bond1
.
length
**
2
+
bond2
.
length
**
2
for
ba
in
bond_angles
:
-
2
*
bond1
.
length
*
bond2
.
length
*
np
.
cos
(
gamma
))
bond1
=
bond_cache
[
make_bond_key
(
ba
.
atom1_name
,
ba
.
atom2_name
)]
bond2
=
bond_cache
[
make_bond_key
(
ba
.
atom2_name
,
ba
.
atom3name
)]
# Propagation of uncertainty assuming uncorrelated errors.
dl_outer
=
0.5
/
length
# Compute distance between atom1 and atom3 using the law of cosines
dl_dgamma
=
(
2
*
bond1
.
length
*
bond2
.
length
*
np
.
sin
(
gamma
))
*
dl_outer
# c^2 = a^2 + b^2 - 2ab*cos(gamma).
dl_db1
=
(
2
*
bond1
.
length
-
2
*
bond2
.
length
*
np
.
cos
(
gamma
))
*
dl_outer
gamma
=
ba
.
angle_rad
dl_db2
=
(
2
*
bond2
.
length
-
2
*
bond1
.
length
*
np
.
cos
(
gamma
))
*
dl_outer
length
=
np
.
sqrt
(
stddev
=
np
.
sqrt
((
dl_dgamma
*
ba
.
stddev
)
**
2
+
bond1
.
length
**
2
(
dl_db1
*
bond1
.
stddev
)
**
2
+
+
bond2
.
length
**
2
(
dl_db2
*
bond2
.
stddev
)
**
2
)
-
2
*
bond1
.
length
*
bond2
.
length
*
np
.
cos
(
gamma
)
residue_virtual_bonds
[
resname
].
append
(
)
Bond
(
ba
.
atom1_name
,
ba
.
atom3name
,
length
,
stddev
))
# Propagation of uncertainty assuming uncorrelated errors.
return
(
residue_bonds
,
dl_outer
=
0.5
/
length
residue_virtual_bonds
,
dl_dgamma
=
(
residue_bond_angles
)
2
*
bond1
.
length
*
bond2
.
length
*
np
.
sin
(
gamma
)
)
*
dl_outer
dl_db1
=
(
2
*
bond1
.
length
-
2
*
bond2
.
length
*
np
.
cos
(
gamma
)
)
*
dl_outer
dl_db2
=
(
2
*
bond2
.
length
-
2
*
bond1
.
length
*
np
.
cos
(
gamma
)
)
*
dl_outer
stddev
=
np
.
sqrt
(
(
dl_dgamma
*
ba
.
stddev
)
**
2
+
(
dl_db1
*
bond1
.
stddev
)
**
2
+
(
dl_db2
*
bond2
.
stddev
)
**
2
)
residue_virtual_bonds
[
resname
].
append
(
Bond
(
ba
.
atom1_name
,
ba
.
atom3name
,
length
,
stddev
)
)
return
(
residue_bonds
,
residue_virtual_bonds
,
residue_bond_angles
)
# Between-residue bond lengths for general bonds (first element) and for Proline
# Between-residue bond lengths for general bonds (first element) and for Proline
...
@@ -491,10 +553,43 @@ between_res_cos_angles_ca_c_n = [-0.4473, 0.0311] # degrees: 116.568 +- 1.995
...
@@ -491,10 +553,43 @@ between_res_cos_angles_ca_c_n = [-0.4473, 0.0311] # degrees: 116.568 +- 1.995
# This mapping is used when we need to store atom data in a format that requires
# This mapping is used when we need to store atom data in a format that requires
# fixed atom data size for every residue (e.g. a numpy array).
# fixed atom data size for every residue (e.g. a numpy array).
atom_types
=
[
atom_types
=
[
'N'
,
'CA'
,
'C'
,
'CB'
,
'O'
,
'CG'
,
'CG1'
,
'CG2'
,
'OG'
,
'OG1'
,
'SG'
,
'CD'
,
"N"
,
'CD1'
,
'CD2'
,
'ND1'
,
'ND2'
,
'OD1'
,
'OD2'
,
'SD'
,
'CE'
,
'CE1'
,
'CE2'
,
'CE3'
,
"CA"
,
'NE'
,
'NE1'
,
'NE2'
,
'OE1'
,
'OE2'
,
'CH2'
,
'NH1'
,
'NH2'
,
'OH'
,
'CZ'
,
'CZ2'
,
"C"
,
'CZ3'
,
'NZ'
,
'OXT'
"CB"
,
"O"
,
"CG"
,
"CG1"
,
"CG2"
,
"OG"
,
"OG1"
,
"SG"
,
"CD"
,
"CD1"
,
"CD2"
,
"ND1"
,
"ND2"
,
"OD1"
,
"OD2"
,
"SD"
,
"CE"
,
"CE1"
,
"CE2"
,
"CE3"
,
"NE"
,
"NE1"
,
"NE2"
,
"OE1"
,
"OE2"
,
"CH2"
,
"NH1"
,
"NH2"
,
"OH"
,
"CZ"
,
"CZ2"
,
"CZ3"
,
"NZ"
,
"OXT"
,
]
]
atom_order
=
{
atom_type
:
i
for
i
,
atom_type
in
enumerate
(
atom_types
)}
atom_order
=
{
atom_type
:
i
for
i
,
atom_type
in
enumerate
(
atom_types
)}
atom_type_num
=
len
(
atom_types
)
# := 37.
atom_type_num
=
len
(
atom_types
)
# := 37.
...
@@ -503,28 +598,252 @@ atom_type_num = len(atom_types) # := 37.
...
@@ -503,28 +598,252 @@ atom_type_num = len(atom_types) # := 37.
# pylint: disable=line-too-long
# pylint: disable=line-too-long
# pylint: disable=bad-whitespace
# pylint: disable=bad-whitespace
restype_name_to_atom14_names
=
{
restype_name_to_atom14_names
=
{
'ALA'
:
[
'N'
,
'CA'
,
'C'
,
'O'
,
'CB'
,
''
,
''
,
''
,
''
,
''
,
''
,
''
,
''
,
''
],
"ALA"
:
[
"N"
,
"CA"
,
"C"
,
"O"
,
"CB"
,
""
,
""
,
""
,
""
,
""
,
""
,
""
,
""
,
""
],
'ARG'
:
[
'N'
,
'CA'
,
'C'
,
'O'
,
'CB'
,
'CG'
,
'CD'
,
'NE'
,
'CZ'
,
'NH1'
,
'NH2'
,
''
,
''
,
''
],
"ARG"
:
[
'ASN'
:
[
'N'
,
'CA'
,
'C'
,
'O'
,
'CB'
,
'CG'
,
'OD1'
,
'ND2'
,
''
,
''
,
''
,
''
,
''
,
''
],
"N"
,
'ASP'
:
[
'N'
,
'CA'
,
'C'
,
'O'
,
'CB'
,
'CG'
,
'OD1'
,
'OD2'
,
''
,
''
,
''
,
''
,
''
,
''
],
"CA"
,
'CYS'
:
[
'N'
,
'CA'
,
'C'
,
'O'
,
'CB'
,
'SG'
,
''
,
''
,
''
,
''
,
''
,
''
,
''
,
''
],
"C"
,
'GLN'
:
[
'N'
,
'CA'
,
'C'
,
'O'
,
'CB'
,
'CG'
,
'CD'
,
'OE1'
,
'NE2'
,
''
,
''
,
''
,
''
,
''
],
"O"
,
'GLU'
:
[
'N'
,
'CA'
,
'C'
,
'O'
,
'CB'
,
'CG'
,
'CD'
,
'OE1'
,
'OE2'
,
''
,
''
,
''
,
''
,
''
],
"CB"
,
'GLY'
:
[
'N'
,
'CA'
,
'C'
,
'O'
,
''
,
''
,
''
,
''
,
''
,
''
,
''
,
''
,
''
,
''
],
"CG"
,
'HIS'
:
[
'N'
,
'CA'
,
'C'
,
'O'
,
'CB'
,
'CG'
,
'ND1'
,
'CD2'
,
'CE1'
,
'NE2'
,
''
,
''
,
''
,
''
],
"CD"
,
'ILE'
:
[
'N'
,
'CA'
,
'C'
,
'O'
,
'CB'
,
'CG1'
,
'CG2'
,
'CD1'
,
''
,
''
,
''
,
''
,
''
,
''
],
"NE"
,
'LEU'
:
[
'N'
,
'CA'
,
'C'
,
'O'
,
'CB'
,
'CG'
,
'CD1'
,
'CD2'
,
''
,
''
,
''
,
''
,
''
,
''
],
"CZ"
,
'LYS'
:
[
'N'
,
'CA'
,
'C'
,
'O'
,
'CB'
,
'CG'
,
'CD'
,
'CE'
,
'NZ'
,
''
,
''
,
''
,
''
,
''
],
"NH1"
,
'MET'
:
[
'N'
,
'CA'
,
'C'
,
'O'
,
'CB'
,
'CG'
,
'SD'
,
'CE'
,
''
,
''
,
''
,
''
,
''
,
''
],
"NH2"
,
'PHE'
:
[
'N'
,
'CA'
,
'C'
,
'O'
,
'CB'
,
'CG'
,
'CD1'
,
'CD2'
,
'CE1'
,
'CE2'
,
'CZ'
,
''
,
''
,
''
],
""
,
'PRO'
:
[
'N'
,
'CA'
,
'C'
,
'O'
,
'CB'
,
'CG'
,
'CD'
,
''
,
''
,
''
,
''
,
''
,
''
,
''
],
""
,
'SER'
:
[
'N'
,
'CA'
,
'C'
,
'O'
,
'CB'
,
'OG'
,
''
,
''
,
''
,
''
,
''
,
''
,
''
,
''
],
""
,
'THR'
:
[
'N'
,
'CA'
,
'C'
,
'O'
,
'CB'
,
'OG1'
,
'CG2'
,
''
,
''
,
''
,
''
,
''
,
''
,
''
],
],
'TRP'
:
[
'N'
,
'CA'
,
'C'
,
'O'
,
'CB'
,
'CG'
,
'CD1'
,
'CD2'
,
'NE1'
,
'CE2'
,
'CE3'
,
'CZ2'
,
'CZ3'
,
'CH2'
],
"ASN"
:
[
'TYR'
:
[
'N'
,
'CA'
,
'C'
,
'O'
,
'CB'
,
'CG'
,
'CD1'
,
'CD2'
,
'CE1'
,
'CE2'
,
'CZ'
,
'OH'
,
''
,
''
],
"N"
,
'VAL'
:
[
'N'
,
'CA'
,
'C'
,
'O'
,
'CB'
,
'CG1'
,
'CG2'
,
''
,
''
,
''
,
''
,
''
,
''
,
''
],
"CA"
,
'UNK'
:
[
''
,
''
,
''
,
''
,
''
,
''
,
''
,
''
,
''
,
''
,
''
,
''
,
''
,
''
],
"C"
,
"O"
,
"CB"
,
"CG"
,
"OD1"
,
"ND2"
,
""
,
""
,
""
,
""
,
""
,
""
,
],
"ASP"
:
[
"N"
,
"CA"
,
"C"
,
"O"
,
"CB"
,
"CG"
,
"OD1"
,
"OD2"
,
""
,
""
,
""
,
""
,
""
,
""
,
],
"CYS"
:
[
"N"
,
"CA"
,
"C"
,
"O"
,
"CB"
,
"SG"
,
""
,
""
,
""
,
""
,
""
,
""
,
""
,
""
],
"GLN"
:
[
"N"
,
"CA"
,
"C"
,
"O"
,
"CB"
,
"CG"
,
"CD"
,
"OE1"
,
"NE2"
,
""
,
""
,
""
,
""
,
""
,
],
"GLU"
:
[
"N"
,
"CA"
,
"C"
,
"O"
,
"CB"
,
"CG"
,
"CD"
,
"OE1"
,
"OE2"
,
""
,
""
,
""
,
""
,
""
,
],
"GLY"
:
[
"N"
,
"CA"
,
"C"
,
"O"
,
""
,
""
,
""
,
""
,
""
,
""
,
""
,
""
,
""
,
""
],
"HIS"
:
[
"N"
,
"CA"
,
"C"
,
"O"
,
"CB"
,
"CG"
,
"ND1"
,
"CD2"
,
"CE1"
,
"NE2"
,
""
,
""
,
""
,
""
,
],
"ILE"
:
[
"N"
,
"CA"
,
"C"
,
"O"
,
"CB"
,
"CG1"
,
"CG2"
,
"CD1"
,
""
,
""
,
""
,
""
,
""
,
""
,
],
"LEU"
:
[
"N"
,
"CA"
,
"C"
,
"O"
,
"CB"
,
"CG"
,
"CD1"
,
"CD2"
,
""
,
""
,
""
,
""
,
""
,
""
,
],
"LYS"
:
[
"N"
,
"CA"
,
"C"
,
"O"
,
"CB"
,
"CG"
,
"CD"
,
"CE"
,
"NZ"
,
""
,
""
,
""
,
""
,
""
,
],
"MET"
:
[
"N"
,
"CA"
,
"C"
,
"O"
,
"CB"
,
"CG"
,
"SD"
,
"CE"
,
""
,
""
,
""
,
""
,
""
,
""
,
],
"PHE"
:
[
"N"
,
"CA"
,
"C"
,
"O"
,
"CB"
,
"CG"
,
"CD1"
,
"CD2"
,
"CE1"
,
"CE2"
,
"CZ"
,
""
,
""
,
""
,
],
"PRO"
:
[
"N"
,
"CA"
,
"C"
,
"O"
,
"CB"
,
"CG"
,
"CD"
,
""
,
""
,
""
,
""
,
""
,
""
,
""
],
"SER"
:
[
"N"
,
"CA"
,
"C"
,
"O"
,
"CB"
,
"OG"
,
""
,
""
,
""
,
""
,
""
,
""
,
""
,
""
],
"THR"
:
[
"N"
,
"CA"
,
"C"
,
"O"
,
"CB"
,
"OG1"
,
"CG2"
,
""
,
""
,
""
,
""
,
""
,
""
,
""
,
],
"TRP"
:
[
"N"
,
"CA"
,
"C"
,
"O"
,
"CB"
,
"CG"
,
"CD1"
,
"CD2"
,
"NE1"
,
"CE2"
,
"CE3"
,
"CZ2"
,
"CZ3"
,
"CH2"
,
],
"TYR"
:
[
"N"
,
"CA"
,
"C"
,
"O"
,
"CB"
,
"CG"
,
"CD1"
,
"CD2"
,
"CE1"
,
"CE2"
,
"CZ"
,
"OH"
,
""
,
""
,
],
"VAL"
:
[
"N"
,
"CA"
,
"C"
,
"O"
,
"CB"
,
"CG1"
,
"CG2"
,
""
,
""
,
""
,
""
,
""
,
""
,
""
,
],
"UNK"
:
[
""
,
""
,
""
,
""
,
""
,
""
,
""
,
""
,
""
,
""
,
""
,
""
,
""
,
""
],
}
}
# pylint: enable=line-too-long
# pylint: enable=line-too-long
# pylint: enable=bad-whitespace
# pylint: enable=bad-whitespace
...
@@ -533,81 +852,102 @@ restype_name_to_atom14_names = {
...
@@ -533,81 +852,102 @@ restype_name_to_atom14_names = {
# This is the standard residue order when coding AA type as a number.
# This is the standard residue order when coding AA type as a number.
# Reproduce it by taking 3-letter AA codes and sorting them alphabetically.
# Reproduce it by taking 3-letter AA codes and sorting them alphabetically.
restypes
=
[
restypes
=
[
'A'
,
'R'
,
'N'
,
'D'
,
'C'
,
'Q'
,
'E'
,
'G'
,
'H'
,
'I'
,
'L'
,
'K'
,
'M'
,
'F'
,
'P'
,
"A"
,
'S'
,
'T'
,
'W'
,
'Y'
,
'V'
"R"
,
"N"
,
"D"
,
"C"
,
"Q"
,
"E"
,
"G"
,
"H"
,
"I"
,
"L"
,
"K"
,
"M"
,
"F"
,
"P"
,
"S"
,
"T"
,
"W"
,
"Y"
,
"V"
,
]
]
restype_order
=
{
restype
:
i
for
i
,
restype
in
enumerate
(
restypes
)}
restype_order
=
{
restype
:
i
for
i
,
restype
in
enumerate
(
restypes
)}
restype_num
=
len
(
restypes
)
# := 20.
restype_num
=
len
(
restypes
)
# := 20.
unk_restype_index
=
restype_num
# Catch-all index for unknown restypes.
unk_restype_index
=
restype_num
# Catch-all index for unknown restypes.
restypes_with_x
=
restypes
+
[
'X'
]
restypes_with_x
=
restypes
+
[
"X"
]
restype_order_with_x
=
{
restype
:
i
for
i
,
restype
in
enumerate
(
restypes_with_x
)}
restype_order_with_x
=
{
restype
:
i
for
i
,
restype
in
enumerate
(
restypes_with_x
)}
def
sequence_to_onehot
(
def
sequence_to_onehot
(
sequence
:
str
,
sequence
:
str
,
mapping
:
Mapping
[
str
,
int
],
map_unknown_to_x
:
bool
=
False
mapping
:
Mapping
[
str
,
int
],
)
->
np
.
ndarray
:
map_unknown_to_x
:
bool
=
False
)
->
np
.
ndarray
:
"""Maps the given sequence into a one-hot encoded matrix.
"""Maps the given sequence into a one-hot encoded matrix.
Args:
Args:
sequence: An amino acid sequence.
sequence: An amino acid sequence.
mapping: A dictionary mapping amino acids to integers.
mapping: A dictionary mapping amino acids to integers.
map_unknown_to_x: If True, any amino acid that is not in the mapping will be
map_unknown_to_x: If True, any amino acid that is not in the mapping will be
mapped to the unknown amino acid 'X'. If the mapping doesn't contain
mapped to the unknown amino acid 'X'. If the mapping doesn't contain
amino acid 'X', an error will be thrown. If False, any amino acid not in
amino acid 'X', an error will be thrown. If False, any amino acid not in
the mapping will throw an error.
the mapping will throw an error.
Returns:
Returns:
A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of
A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of
the sequence.
the sequence.
Raises:
Raises:
ValueError: If the mapping doesn't contain values from 0 to
ValueError: If the mapping doesn't contain values from 0 to
num_unique_aas - 1 without any gaps.
num_unique_aas - 1 without any gaps.
"""
"""
num_entries
=
max
(
mapping
.
values
())
+
1
num_entries
=
max
(
mapping
.
values
())
+
1
if
sorted
(
set
(
mapping
.
values
()))
!=
list
(
range
(
num_entries
)):
if
sorted
(
set
(
mapping
.
values
()))
!=
list
(
range
(
num_entries
)):
raise
ValueError
(
raise
ValueError
(
'The mapping must have values from 0 to num_unique_aas-1 '
"The mapping must have values from 0 to num_unique_aas-1 "
'without any gaps. Got: %s'
%
sorted
(
mapping
.
values
()))
"without any gaps. Got: %s"
%
sorted
(
mapping
.
values
())
)
one_hot_arr
=
np
.
zeros
((
len
(
sequence
),
num_entries
),
dtype
=
np
.
int32
)
one_hot_arr
=
np
.
zeros
((
len
(
sequence
),
num_entries
),
dtype
=
np
.
int32
)
for
aa_index
,
aa_type
in
enumerate
(
sequence
):
if
map_unknown_to_x
:
for
aa_index
,
aa_type
in
enumerate
(
sequence
):
if
aa_type
.
isalpha
()
and
aa_type
.
isupper
():
if
map_unknown_to_x
:
aa_id
=
mapping
.
get
(
aa_type
,
mapping
[
'X'
])
if
aa_type
.
isalpha
()
and
aa_type
.
isupper
():
else
:
aa_id
=
mapping
.
get
(
aa_type
,
mapping
[
"X"
])
raise
ValueError
(
f
'Invalid character in the sequence:
{
aa_type
}
'
)
else
:
else
:
raise
ValueError
(
aa_id
=
mapping
[
aa_type
]
f
"Invalid character in the sequence:
{
aa_type
}
"
one_hot_arr
[
aa_index
,
aa_id
]
=
1
)
else
:
return
one_hot_arr
aa_id
=
mapping
[
aa_type
]
one_hot_arr
[
aa_index
,
aa_id
]
=
1
return
one_hot_arr
restype_1to3
=
{
restype_1to3
=
{
'A'
:
'
ALA
'
,
"A"
:
"
ALA
"
,
'R'
:
'
ARG
'
,
"R"
:
"
ARG
"
,
'N'
:
'
ASN
'
,
"N"
:
"
ASN
"
,
'D'
:
'
ASP
'
,
"D"
:
"
ASP
"
,
'C'
:
'
CYS
'
,
"C"
:
"
CYS
"
,
'Q'
:
'
GLN
'
,
"Q"
:
"
GLN
"
,
'E'
:
'
GLU
'
,
"E"
:
"
GLU
"
,
'G'
:
'
GLY
'
,
"G"
:
"
GLY
"
,
'H'
:
'
HIS
'
,
"H"
:
"
HIS
"
,
'I'
:
'
ILE
'
,
"I"
:
"
ILE
"
,
'L'
:
'
LEU
'
,
"L"
:
"
LEU
"
,
'K'
:
'
LYS
'
,
"K"
:
"
LYS
"
,
'M'
:
'
MET
'
,
"M"
:
"
MET
"
,
'F'
:
'
PHE
'
,
"F"
:
"
PHE
"
,
'P'
:
'
PRO
'
,
"P"
:
"
PRO
"
,
'S'
:
'
SER
'
,
"S"
:
"
SER
"
,
'T'
:
'
THR
'
,
"T"
:
"
THR
"
,
'W'
:
'
TRP
'
,
"W"
:
"
TRP
"
,
'Y'
:
'
TYR
'
,
"Y"
:
"
TYR
"
,
'V'
:
'
VAL
'
,
"V"
:
"
VAL
"
,
}
}
...
@@ -618,7 +958,7 @@ restype_1to3 = {
...
@@ -618,7 +958,7 @@ restype_1to3 = {
restype_3to1
=
{
v
:
k
for
k
,
v
in
restype_1to3
.
items
()}
restype_3to1
=
{
v
:
k
for
k
,
v
in
restype_1to3
.
items
()}
# Define a restype name for all unknown residues.
# Define a restype name for all unknown residues.
unk_restype
=
'
UNK
'
unk_restype
=
"
UNK
"
resnames
=
[
restype_1to3
[
r
]
for
r
in
restypes
]
+
[
unk_restype
]
resnames
=
[
restype_1to3
[
r
]
for
r
in
restypes
]
+
[
unk_restype
]
resname_to_idx
=
{
resname
:
i
for
i
,
resname
in
enumerate
(
resnames
)}
resname_to_idx
=
{
resname
:
i
for
i
,
resname
in
enumerate
(
resnames
)}
...
@@ -632,78 +972,79 @@ resname_to_idx = {resname: i for i, resname in enumerate(resnames)}
...
@@ -632,78 +972,79 @@ resname_to_idx = {resname: i for i, resname in enumerate(resnames)}
# codes is put at the end (20 and 21) so that they can easily be ignored if
# codes is put at the end (20 and 21) so that they can easily be ignored if
# desired.
# desired.
HHBLITS_AA_TO_ID
=
{
HHBLITS_AA_TO_ID
=
{
'A'
:
0
,
"A"
:
0
,
'B'
:
2
,
"B"
:
2
,
'C'
:
1
,
"C"
:
1
,
'D'
:
2
,
"D"
:
2
,
'E'
:
3
,
"E"
:
3
,
'F'
:
4
,
"F"
:
4
,
'G'
:
5
,
"G"
:
5
,
'H'
:
6
,
"H"
:
6
,
'I'
:
7
,
"I"
:
7
,
'J'
:
20
,
"J"
:
20
,
'K'
:
8
,
"K"
:
8
,
'L'
:
9
,
"L"
:
9
,
'M'
:
10
,
"M"
:
10
,
'N'
:
11
,
"N"
:
11
,
'O'
:
20
,
"O"
:
20
,
'P'
:
12
,
"P"
:
12
,
'Q'
:
13
,
"Q"
:
13
,
'R'
:
14
,
"R"
:
14
,
'S'
:
15
,
"S"
:
15
,
'T'
:
16
,
"T"
:
16
,
'U'
:
1
,
"U"
:
1
,
'V'
:
17
,
"V"
:
17
,
'W'
:
18
,
"W"
:
18
,
'X'
:
20
,
"X"
:
20
,
'Y'
:
19
,
"Y"
:
19
,
'Z'
:
3
,
"Z"
:
3
,
'-'
:
21
,
"-"
:
21
,
}
}
# Partial inversion of HHBLITS_AA_TO_ID.
# Partial inversion of HHBLITS_AA_TO_ID.
ID_TO_HHBLITS_AA
=
{
ID_TO_HHBLITS_AA
=
{
0
:
'A'
,
0
:
"A"
,
1
:
'C'
,
# Also U.
1
:
"C"
,
# Also U.
2
:
'D'
,
# Also B.
2
:
"D"
,
# Also B.
3
:
'E'
,
# Also Z.
3
:
"E"
,
# Also Z.
4
:
'F'
,
4
:
"F"
,
5
:
'G'
,
5
:
"G"
,
6
:
'H'
,
6
:
"H"
,
7
:
'I'
,
7
:
"I"
,
8
:
'K'
,
8
:
"K"
,
9
:
'L'
,
9
:
"L"
,
10
:
'M'
,
10
:
"M"
,
11
:
'N'
,
11
:
"N"
,
12
:
'P'
,
12
:
"P"
,
13
:
'Q'
,
13
:
"Q"
,
14
:
'R'
,
14
:
"R"
,
15
:
'S'
,
15
:
"S"
,
16
:
'T'
,
16
:
"T"
,
17
:
'V'
,
17
:
"V"
,
18
:
'W'
,
18
:
"W"
,
19
:
'Y'
,
19
:
"Y"
,
20
:
'X'
,
# Includes J and O.
20
:
"X"
,
# Includes J and O.
21
:
'-'
,
21
:
"-"
,
}
}
restypes_with_x_and_gap
=
restypes
+
[
'X'
,
'-'
]
restypes_with_x_and_gap
=
restypes
+
[
"X"
,
"-"
]
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
=
tuple
(
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
=
tuple
(
restypes_with_x_and_gap
.
index
(
ID_TO_HHBLITS_AA
[
i
])
restypes_with_x_and_gap
.
index
(
ID_TO_HHBLITS_AA
[
i
])
for
i
in
range
(
len
(
restypes_with_x_and_gap
)))
for
i
in
range
(
len
(
restypes_with_x_and_gap
))
)
def
_make_standard_atom_mask
()
->
np
.
ndarray
:
def
_make_standard_atom_mask
()
->
np
.
ndarray
:
"""Returns [num_res_types, num_atom_types] mask array."""
"""Returns [num_res_types, num_atom_types] mask array."""
# +1 to account for unknown (all 0s).
# +1 to account for unknown (all 0s).
mask
=
np
.
zeros
([
restype_num
+
1
,
atom_type_num
],
dtype
=
np
.
int32
)
mask
=
np
.
zeros
([
restype_num
+
1
,
atom_type_num
],
dtype
=
np
.
int32
)
for
restype
,
restype_letter
in
enumerate
(
restypes
):
for
restype
,
restype_letter
in
enumerate
(
restypes
):
restype_name
=
restype_1to3
[
restype_letter
]
restype_name
=
restype_1to3
[
restype_letter
]
atom_names
=
residue_atoms
[
restype_name
]
atom_names
=
residue_atoms
[
restype_name
]
for
atom_name
in
atom_names
:
for
atom_name
in
atom_names
:
atom_type
=
atom_order
[
atom_name
]
atom_type
=
atom_order
[
atom_name
]
mask
[
restype
,
atom_type
]
=
1
mask
[
restype
,
atom_type
]
=
1
return
mask
return
mask
STANDARD_ATOM_MASK
=
_make_standard_atom_mask
()
STANDARD_ATOM_MASK
=
_make_standard_atom_mask
()
...
@@ -712,25 +1053,26 @@ STANDARD_ATOM_MASK = _make_standard_atom_mask()
...
@@ -712,25 +1053,26 @@ STANDARD_ATOM_MASK = _make_standard_atom_mask()
# A one hot representation for the first and second atoms defining the axis
# A one hot representation for the first and second atoms defining the axis
# of rotation for each chi-angle in each residue.
# of rotation for each chi-angle in each residue.
def
chi_angle_atom
(
atom_index
:
int
)
->
np
.
ndarray
:
def
chi_angle_atom
(
atom_index
:
int
)
->
np
.
ndarray
:
"""Define chi-angle rigid groups via one-hot representations."""
"""Define chi-angle rigid groups via one-hot representations."""
chi_angles_index
=
{}
chi_angles_index
=
{}
one_hots
=
[]
one_hots
=
[]
for
k
,
v
in
chi_angles_atoms
.
items
():
for
k
,
v
in
chi_angles_atoms
.
items
():
indices
=
[
atom_types
.
index
(
s
[
atom_index
])
for
s
in
v
]
indices
=
[
atom_types
.
index
(
s
[
atom_index
])
for
s
in
v
]
indices
.
extend
([
-
1
]
*
(
4
-
len
(
indices
)))
indices
.
extend
([
-
1
]
*
(
4
-
len
(
indices
)))
chi_angles_index
[
k
]
=
indices
chi_angles_index
[
k
]
=
indices
for
r
in
restypes
:
for
r
in
restypes
:
res3
=
restype_1to3
[
r
]
res3
=
restype_1to3
[
r
]
one_hot
=
np
.
eye
(
atom_type_num
)[
chi_angles_index
[
res3
]]
one_hot
=
np
.
eye
(
atom_type_num
)[
chi_angles_index
[
res3
]]
one_hots
.
append
(
one_hot
)
one_hots
.
append
(
one_hot
)
one_hots
.
append
(
np
.
zeros
([
4
,
atom_type_num
]))
# Add zeros for residue `X`.
one_hots
.
append
(
np
.
zeros
([
4
,
atom_type_num
]))
# Add zeros for residue `X`.
one_hot
=
np
.
stack
(
one_hots
,
axis
=
0
)
one_hot
=
np
.
stack
(
one_hots
,
axis
=
0
)
one_hot
=
np
.
transpose
(
one_hot
,
[
0
,
2
,
1
])
one_hot
=
np
.
transpose
(
one_hot
,
[
0
,
2
,
1
])
return
one_hot
return
one_hot
chi_atom_1_one_hot
=
chi_angle_atom
(
1
)
chi_atom_1_one_hot
=
chi_angle_atom
(
1
)
chi_atom_2_one_hot
=
chi_angle_atom
(
2
)
chi_atom_2_one_hot
=
chi_angle_atom
(
2
)
...
@@ -738,35 +1080,41 @@ chi_atom_2_one_hot = chi_angle_atom(2)
...
@@ -738,35 +1080,41 @@ chi_atom_2_one_hot = chi_angle_atom(2)
# An array like chi_angles_atoms but using indices rather than names.
# An array like chi_angles_atoms but using indices rather than names.
chi_angles_atom_indices
=
[
chi_angles_atoms
[
restype_1to3
[
r
]]
for
r
in
restypes
]
chi_angles_atom_indices
=
[
chi_angles_atoms
[
restype_1to3
[
r
]]
for
r
in
restypes
]
chi_angles_atom_indices
=
tree
.
map_structure
(
chi_angles_atom_indices
=
tree
.
map_structure
(
lambda
atom_name
:
atom_order
[
atom_name
],
chi_angles_atom_indices
)
lambda
atom_name
:
atom_order
[
atom_name
],
chi_angles_atom_indices
chi_angles_atom_indices
=
np
.
array
([
)
chi_atoms
+
([[
0
,
0
,
0
,
0
]]
*
(
4
-
len
(
chi_atoms
)))
chi_angles_atom_indices
=
np
.
array
(
for
chi_atoms
in
chi_angles_atom_indices
])
[
chi_atoms
+
([[
0
,
0
,
0
,
0
]]
*
(
4
-
len
(
chi_atoms
)))
for
chi_atoms
in
chi_angles_atom_indices
]
)
# Mapping from (res_name, atom_name) pairs to the atom's chi group index
# Mapping from (res_name, atom_name) pairs to the atom's chi group index
# and atom index within that group.
# and atom index within that group.
chi_groups_for_atom
=
collections
.
defaultdict
(
list
)
chi_groups_for_atom
=
collections
.
defaultdict
(
list
)
for
res_name
,
chi_angle_atoms_for_res
in
chi_angles_atoms
.
items
():
for
res_name
,
chi_angle_atoms_for_res
in
chi_angles_atoms
.
items
():
for
chi_group_i
,
chi_group
in
enumerate
(
chi_angle_atoms_for_res
):
for
chi_group_i
,
chi_group
in
enumerate
(
chi_angle_atoms_for_res
):
for
atom_i
,
atom
in
enumerate
(
chi_group
):
for
atom_i
,
atom
in
enumerate
(
chi_group
):
chi_groups_for_atom
[(
res_name
,
atom
)].
append
((
chi_group_i
,
atom_i
))
chi_groups_for_atom
[(
res_name
,
atom
)].
append
((
chi_group_i
,
atom_i
))
chi_groups_for_atom
=
dict
(
chi_groups_for_atom
)
chi_groups_for_atom
=
dict
(
chi_groups_for_atom
)
def
_make_rigid_transformation_4x4
(
ex
,
ey
,
translation
):
def
_make_rigid_transformation_4x4
(
ex
,
ey
,
translation
):
"""Create a rigid 4x4 transformation matrix from two axes and transl."""
"""Create a rigid 4x4 transformation matrix from two axes and transl."""
# Normalize ex.
# Normalize ex.
ex_normalized
=
ex
/
np
.
linalg
.
norm
(
ex
)
ex_normalized
=
ex
/
np
.
linalg
.
norm
(
ex
)
# make ey perpendicular to ex
# make ey perpendicular to ex
ey_normalized
=
ey
-
np
.
dot
(
ey
,
ex_normalized
)
*
ex_normalized
ey_normalized
=
ey
-
np
.
dot
(
ey
,
ex_normalized
)
*
ex_normalized
ey_normalized
/=
np
.
linalg
.
norm
(
ey_normalized
)
ey_normalized
/=
np
.
linalg
.
norm
(
ey_normalized
)
# compute ez as cross product
# compute ez as cross product
eznorm
=
np
.
cross
(
ex_normalized
,
ey_normalized
)
eznorm
=
np
.
cross
(
ex_normalized
,
ey_normalized
)
m
=
np
.
stack
([
ex_normalized
,
ey_normalized
,
eznorm
,
translation
]).
transpose
()
m
=
np
.
stack
(
m
=
np
.
concatenate
([
m
,
[[
0.
,
0.
,
0.
,
1.
]]],
axis
=
0
)
[
ex_normalized
,
ey_normalized
,
eznorm
,
translation
]
return
m
).
transpose
()
m
=
np
.
concatenate
([
m
,
[[
0.0
,
0.0
,
0.0
,
1.0
]]],
axis
=
0
)
return
m
# create an array with (restype, atomtype) --> rigid_group_idx
# create an array with (restype, atomtype) --> rigid_group_idx
...
@@ -783,138 +1131,173 @@ restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
...
@@ -783,138 +1131,173 @@ restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
def
_make_rigid_group_constants
():
def
_make_rigid_group_constants
():
"""Fill the arrays above."""
"""Fill the arrays above."""
for
restype
,
restype_letter
in
enumerate
(
restypes
):
for
restype
,
restype_letter
in
enumerate
(
restypes
):
resname
=
restype_1to3
[
restype_letter
]
resname
=
restype_1to3
[
restype_letter
]
for
atomname
,
group_idx
,
atom_position
in
rigid_group_atom_positions
[
for
atomname
,
group_idx
,
atom_position
in
rigid_group_atom_positions
[
resname
]:
resname
atomtype
=
atom_order
[
atomname
]
]:
restype_atom37_to_rigid_group
[
restype
,
atomtype
]
=
group_idx
atomtype
=
atom_order
[
atomname
]
restype_atom37_mask
[
restype
,
atomtype
]
=
1
restype_atom37_to_rigid_group
[
restype
,
atomtype
]
=
group_idx
restype_atom37_rigid_group_positions
[
restype
,
atomtype
,
:]
=
atom_position
restype_atom37_mask
[
restype
,
atomtype
]
=
1
restype_atom37_rigid_group_positions
[
atom14idx
=
restype_name_to_atom14_names
[
resname
].
index
(
atomname
)
restype
,
atomtype
,
:
restype_atom14_to_rigid_group
[
restype
,
atom14idx
]
=
group_idx
]
=
atom_position
restype_atom14_mask
[
restype
,
atom14idx
]
=
1
restype_atom14_rigid_group_positions
[
restype
,
atom14idx
=
restype_name_to_atom14_names
[
resname
].
index
(
atomname
)
atom14idx
,
:]
=
atom_position
restype_atom14_to_rigid_group
[
restype
,
atom14idx
]
=
group_idx
restype_atom14_mask
[
restype
,
atom14idx
]
=
1
for
restype
,
restype_letter
in
enumerate
(
restypes
):
restype_atom14_rigid_group_positions
[
resname
=
restype_1to3
[
restype_letter
]
restype
,
atom14idx
,
:
atom_positions
=
{
name
:
np
.
array
(
pos
)
for
name
,
_
,
pos
]
=
atom_position
in
rigid_group_atom_positions
[
resname
]}
for
restype
,
restype_letter
in
enumerate
(
restypes
):
# backbone to backbone is the identity transform
resname
=
restype_1to3
[
restype_letter
]
restype_rigid_group_default_frame
[
restype
,
0
,
:,
:]
=
np
.
eye
(
4
)
atom_positions
=
{
name
:
np
.
array
(
pos
)
# pre-omega-frame to backbone (currently dummy identity matrix)
for
name
,
_
,
pos
in
rigid_group_atom_positions
[
resname
]
restype_rigid_group_default_frame
[
restype
,
1
,
:,
:]
=
np
.
eye
(
4
)
}
# phi-frame to backbone
# backbone to backbone is the identity transform
mat
=
_make_rigid_transformation_4x4
(
restype_rigid_group_default_frame
[
restype
,
0
,
:,
:]
=
np
.
eye
(
4
)
ex
=
atom_positions
[
'N'
]
-
atom_positions
[
'CA'
],
ey
=
np
.
array
([
1.
,
0.
,
0.
]),
# pre-omega-frame to backbone (currently dummy identity matrix)
translation
=
atom_positions
[
'N'
])
restype_rigid_group_default_frame
[
restype
,
1
,
:,
:]
=
np
.
eye
(
4
)
restype_rigid_group_default_frame
[
restype
,
2
,
:,
:]
=
mat
# phi-frame to backbone
# psi-frame to backbone
mat
=
_make_rigid_transformation_4x4
(
ex
=
atom_positions
[
'C'
]
-
atom_positions
[
'CA'
],
ey
=
atom_positions
[
'CA'
]
-
atom_positions
[
'N'
],
translation
=
atom_positions
[
'C'
])
restype_rigid_group_default_frame
[
restype
,
3
,
:,
:]
=
mat
# chi1-frame to backbone
if
chi_angles_mask
[
restype
][
0
]:
base_atom_names
=
chi_angles_atoms
[
resname
][
0
]
base_atom_positions
=
[
atom_positions
[
name
]
for
name
in
base_atom_names
]
mat
=
_make_rigid_transformation_4x4
(
ex
=
base_atom_positions
[
2
]
-
base_atom_positions
[
1
],
ey
=
base_atom_positions
[
0
]
-
base_atom_positions
[
1
],
translation
=
base_atom_positions
[
2
])
restype_rigid_group_default_frame
[
restype
,
4
,
:,
:]
=
mat
# chi2-frame to chi1-frame
# chi3-frame to chi2-frame
# chi4-frame to chi3-frame
# luckily all rotation axes for the next frame start at (0,0,0) of the
# previous frame
for
chi_idx
in
range
(
1
,
4
):
if
chi_angles_mask
[
restype
][
chi_idx
]:
axis_end_atom_name
=
chi_angles_atoms
[
resname
][
chi_idx
][
2
]
axis_end_atom_position
=
atom_positions
[
axis_end_atom_name
]
mat
=
_make_rigid_transformation_4x4
(
mat
=
_make_rigid_transformation_4x4
(
ex
=
axis_end_atom_position
,
ex
=
atom_positions
[
"N"
]
-
atom_positions
[
"CA"
],
ey
=
np
.
array
([
-
1.
,
0.
,
0.
]),
ey
=
np
.
array
([
1.0
,
0.0
,
0.0
]),
translation
=
axis_end_atom_position
)
translation
=
atom_positions
[
"N"
],
restype_rigid_group_default_frame
[
restype
,
4
+
chi_idx
,
:,
:]
=
mat
)
restype_rigid_group_default_frame
[
restype
,
2
,
:,
:]
=
mat
# psi-frame to backbone
mat
=
_make_rigid_transformation_4x4
(
ex
=
atom_positions
[
"C"
]
-
atom_positions
[
"CA"
],
ey
=
atom_positions
[
"CA"
]
-
atom_positions
[
"N"
],
translation
=
atom_positions
[
"C"
],
)
restype_rigid_group_default_frame
[
restype
,
3
,
:,
:]
=
mat
# chi1-frame to backbone
if
chi_angles_mask
[
restype
][
0
]:
base_atom_names
=
chi_angles_atoms
[
resname
][
0
]
base_atom_positions
=
[
atom_positions
[
name
]
for
name
in
base_atom_names
]
mat
=
_make_rigid_transformation_4x4
(
ex
=
base_atom_positions
[
2
]
-
base_atom_positions
[
1
],
ey
=
base_atom_positions
[
0
]
-
base_atom_positions
[
1
],
translation
=
base_atom_positions
[
2
],
)
restype_rigid_group_default_frame
[
restype
,
4
,
:,
:]
=
mat
# chi2-frame to chi1-frame
# chi3-frame to chi2-frame
# chi4-frame to chi3-frame
# luckily all rotation axes for the next frame start at (0,0,0) of the
# previous frame
for
chi_idx
in
range
(
1
,
4
):
if
chi_angles_mask
[
restype
][
chi_idx
]:
axis_end_atom_name
=
chi_angles_atoms
[
resname
][
chi_idx
][
2
]
axis_end_atom_position
=
atom_positions
[
axis_end_atom_name
]
mat
=
_make_rigid_transformation_4x4
(
ex
=
axis_end_atom_position
,
ey
=
np
.
array
([
-
1.0
,
0.0
,
0.0
]),
translation
=
axis_end_atom_position
,
)
restype_rigid_group_default_frame
[
restype
,
4
+
chi_idx
,
:,
:
]
=
mat
_make_rigid_group_constants
()
_make_rigid_group_constants
()
def
make_atom14_dists_bounds
(
overlap_tolerance
=
1.5
,
def
make_atom14_dists_bounds
(
bond_length_tolerance_factor
=
15
):
overlap_tolerance
=
1.5
,
bond_length_tolerance_factor
=
15
"""compute upper and lower bounds for bonds to assess violations."""
):
restype_atom14_bond_lower_bound
=
np
.
zeros
([
21
,
14
,
14
],
np
.
float32
)
"""compute upper and lower bounds for bonds to assess violations."""
restype_atom14_bond_upper_bound
=
np
.
zeros
([
21
,
14
,
14
],
np
.
float32
)
restype_atom14_bond_lower_bound
=
np
.
zeros
([
21
,
14
,
14
],
np
.
float32
)
restype_atom14_bond_stddev
=
np
.
zeros
([
21
,
14
,
14
],
np
.
float32
)
restype_atom14_bond_upper_bound
=
np
.
zeros
([
21
,
14
,
14
],
np
.
float32
)
residue_bonds
,
residue_virtual_bonds
,
_
=
load_stereo_chemical_props
()
restype_atom14_bond_stddev
=
np
.
zeros
([
21
,
14
,
14
],
np
.
float32
)
for
restype
,
restype_letter
in
enumerate
(
restypes
):
residue_bonds
,
residue_virtual_bonds
,
_
=
load_stereo_chemical_props
()
resname
=
restype_1to3
[
restype_letter
]
for
restype
,
restype_letter
in
enumerate
(
restypes
):
atom_list
=
restype_name_to_atom14_names
[
resname
]
resname
=
restype_1to3
[
restype_letter
]
atom_list
=
restype_name_to_atom14_names
[
resname
]
# create lower and upper bounds for clashes
for
atom1_idx
,
atom1_name
in
enumerate
(
atom_list
):
# create lower and upper bounds for clashes
if
not
atom1_name
:
for
atom1_idx
,
atom1_name
in
enumerate
(
atom_list
):
continue
if
not
atom1_name
:
atom1_radius
=
van_der_waals_radius
[
atom1_name
[
0
]]
continue
for
atom2_idx
,
atom2_name
in
enumerate
(
atom_list
):
atom1_radius
=
van_der_waals_radius
[
atom1_name
[
0
]]
if
(
not
atom2_name
)
or
atom1_idx
==
atom2_idx
:
for
atom2_idx
,
atom2_name
in
enumerate
(
atom_list
):
continue
if
(
not
atom2_name
)
or
atom1_idx
==
atom2_idx
:
atom2_radius
=
van_der_waals_radius
[
atom2_name
[
0
]]
continue
lower
=
atom1_radius
+
atom2_radius
-
overlap_tolerance
atom2_radius
=
van_der_waals_radius
[
atom2_name
[
0
]]
upper
=
1e10
lower
=
atom1_radius
+
atom2_radius
-
overlap_tolerance
restype_atom14_bond_lower_bound
[
restype
,
atom1_idx
,
atom2_idx
]
=
lower
upper
=
1e10
restype_atom14_bond_lower_bound
[
restype
,
atom2_idx
,
atom1_idx
]
=
lower
restype_atom14_bond_lower_bound
[
restype_atom14_bond_upper_bound
[
restype
,
atom1_idx
,
atom2_idx
]
=
upper
restype
,
atom1_idx
,
atom2_idx
restype_atom14_bond_upper_bound
[
restype
,
atom2_idx
,
atom1_idx
]
=
upper
]
=
lower
restype_atom14_bond_lower_bound
[
# overwrite lower and upper bounds for bonds and angles
restype
,
atom2_idx
,
atom1_idx
for
b
in
residue_bonds
[
resname
]
+
residue_virtual_bonds
[
resname
]:
]
=
lower
atom1_idx
=
atom_list
.
index
(
b
.
atom1_name
)
restype_atom14_bond_upper_bound
[
atom2_idx
=
atom_list
.
index
(
b
.
atom2_name
)
restype
,
atom1_idx
,
atom2_idx
lower
=
b
.
length
-
bond_length_tolerance_factor
*
b
.
stddev
]
=
upper
upper
=
b
.
length
+
bond_length_tolerance_factor
*
b
.
stddev
restype_atom14_bond_upper_bound
[
restype_atom14_bond_lower_bound
[
restype
,
atom1_idx
,
atom2_idx
]
=
lower
restype
,
atom2_idx
,
atom1_idx
restype_atom14_bond_lower_bound
[
restype
,
atom2_idx
,
atom1_idx
]
=
lower
]
=
upper
restype_atom14_bond_upper_bound
[
restype
,
atom1_idx
,
atom2_idx
]
=
upper
restype_atom14_bond_upper_bound
[
restype
,
atom2_idx
,
atom1_idx
]
=
upper
# overwrite lower and upper bounds for bonds and angles
restype_atom14_bond_stddev
[
restype
,
atom1_idx
,
atom2_idx
]
=
b
.
stddev
for
b
in
residue_bonds
[
resname
]
+
residue_virtual_bonds
[
resname
]:
restype_atom14_bond_stddev
[
restype
,
atom2_idx
,
atom1_idx
]
=
b
.
stddev
atom1_idx
=
atom_list
.
index
(
b
.
atom1_name
)
return
{
'lower_bound'
:
restype_atom14_bond_lower_bound
,
# shape (21,14,14)
atom2_idx
=
atom_list
.
index
(
b
.
atom2_name
)
'upper_bound'
:
restype_atom14_bond_upper_bound
,
# shape (21,14,14)
lower
=
b
.
length
-
bond_length_tolerance_factor
*
b
.
stddev
'stddev'
:
restype_atom14_bond_stddev
,
# shape (21,14,14)
upper
=
b
.
length
+
bond_length_tolerance_factor
*
b
.
stddev
}
restype_atom14_bond_lower_bound
[
restype
,
atom1_idx
,
atom2_idx
]
=
lower
restype_atom14_bond_lower_bound
[
restype
,
atom2_idx
,
atom1_idx
]
=
lower
restype_atom14_bond_upper_bound
[
restype
,
atom1_idx
,
atom2_idx
]
=
upper
restype_atom14_bond_upper_bound
[
restype
,
atom2_idx
,
atom1_idx
]
=
upper
restype_atom14_bond_stddev
[
restype
,
atom1_idx
,
atom2_idx
]
=
b
.
stddev
restype_atom14_bond_stddev
[
restype
,
atom2_idx
,
atom1_idx
]
=
b
.
stddev
return
{
"lower_bound"
:
restype_atom14_bond_lower_bound
,
# shape (21,14,14)
"upper_bound"
:
restype_atom14_bond_upper_bound
,
# shape (21,14,14)
"stddev"
:
restype_atom14_bond_stddev
,
# shape (21,14,14)
}
restype_atom14_ambiguous_atoms
=
np
.
zeros
((
21
,
14
),
dtype
=
np
.
float32
)
restype_atom14_ambiguous_atoms
=
np
.
zeros
((
21
,
14
),
dtype
=
np
.
float32
)
restype_atom14_ambiguous_atoms_swap_idx
=
(
restype_atom14_ambiguous_atoms_swap_idx
=
np
.
tile
(
np
.
tile
(
np
.
arange
(
14
,
dtype
=
np
.
int
),
(
21
,
1
)
)
np
.
arange
(
14
,
dtype
=
np
.
int
),
(
21
,
1
)
)
)
def
_make_atom14_ambiguity_feats
():
def
_make_atom14_ambiguity_feats
():
for
res
,
pairs
in
residue_atom_renaming_swaps
.
items
():
for
res
,
pairs
in
residue_atom_renaming_swaps
.
items
():
res_idx
=
restype_order
[
restype_3to1
[
res
]]
res_idx
=
restype_order
[
restype_3to1
[
res
]]
for
atom1
,
atom2
in
pairs
.
items
():
for
atom1
,
atom2
in
pairs
.
items
():
atom1_idx
=
restype_name_to_atom14_names
[
res
].
index
(
atom1
)
atom1_idx
=
restype_name_to_atom14_names
[
res
].
index
(
atom1
)
atom2_idx
=
restype_name_to_atom14_names
[
res
].
index
(
atom2
)
atom2_idx
=
restype_name_to_atom14_names
[
res
].
index
(
atom2
)
restype_atom14_ambiguous_atoms
[
res_idx
,
atom1_idx
]
=
1
restype_atom14_ambiguous_atoms
[
res_idx
,
atom1_idx
]
=
1
restype_atom14_ambiguous_atoms
[
res_idx
,
atom2_idx
]
=
1
restype_atom14_ambiguous_atoms
[
res_idx
,
atom2_idx
]
=
1
restype_atom14_ambiguous_atoms_swap_idx
[
res_idx
,
atom1_idx
]
=
(
restype_atom14_ambiguous_atoms_swap_idx
[
atom2_idx
res_idx
,
atom1_idx
)
]
=
atom2_idx
restype_atom14_ambiguous_atoms_swap_idx
[
res_idx
,
atom2_idx
]
=
(
restype_atom14_ambiguous_atoms_swap_idx
[
atom1_idx
res_idx
,
atom2_idx
)
]
=
atom1_idx
_make_atom14_ambiguity_feats
()
_make_atom14_ambiguity_feats
()
openfold/utils/__init__.py
View file @
07e64267
...
@@ -3,13 +3,14 @@ import glob
...
@@ -3,13 +3,14 @@ import glob
import
importlib
as
importlib
import
importlib
as
importlib
_files
=
glob
.
glob
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"*.py"
))
_files
=
glob
.
glob
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"*.py"
))
__all__
=
[
os
.
path
.
basename
(
f
)[:
-
3
]
for
f
in
_files
if
os
.
path
.
isfile
(
f
)
and
not
f
.
endswith
(
"__init__.py"
)]
__all__
=
[
_modules
=
[(
m
,
importlib
.
import_module
(
'.'
+
m
,
__name__
))
for
m
in
__all__
]
os
.
path
.
basename
(
f
)[:
-
3
]
for
f
in
_files
if
os
.
path
.
isfile
(
f
)
and
not
f
.
endswith
(
"__init__.py"
)
]
_modules
=
[(
m
,
importlib
.
import_module
(
"."
+
m
,
__name__
))
for
m
in
__all__
]
for
_m
in
_modules
:
for
_m
in
_modules
:
globals
()[
_m
[
0
]]
=
_m
[
1
]
globals
()[
_m
[
0
]]
=
_m
[
1
]
# Avoid needlessly cluttering the global namespace
# Avoid needlessly cluttering the global namespace
del
_files
,
_m
,
_modules
del
_files
,
_m
,
_modules
openfold/utils/affine_utils.py
View file @
07e64267
...
@@ -18,21 +18,48 @@ import torch
...
@@ -18,21 +18,48 @@ import torch
def
rot_matmul
(
a
,
b
):
def
rot_matmul
(
a
,
b
):
row_1
=
torch
.
stack
([
row_1
=
torch
.
stack
(
a
[...,
0
,
0
]
*
b
[...,
0
,
0
]
+
a
[...,
0
,
1
]
*
b
[...,
1
,
0
]
+
a
[...,
0
,
2
]
*
b
[...,
2
,
0
],
[
a
[...,
0
,
0
]
*
b
[...,
0
,
1
]
+
a
[...,
0
,
1
]
*
b
[...,
1
,
1
]
+
a
[...,
0
,
2
]
*
b
[...,
2
,
1
],
a
[...,
0
,
0
]
*
b
[...,
0
,
0
]
a
[...,
0
,
0
]
*
b
[...,
0
,
2
]
+
a
[...,
0
,
1
]
*
b
[...,
1
,
2
]
+
a
[...,
0
,
2
]
*
b
[...,
2
,
2
],
+
a
[...,
0
,
1
]
*
b
[...,
1
,
0
]
],
dim
=-
1
)
+
a
[...,
0
,
2
]
*
b
[...,
2
,
0
],
row_2
=
torch
.
stack
([
a
[...,
0
,
0
]
*
b
[...,
0
,
1
]
a
[...,
1
,
0
]
*
b
[...,
0
,
0
]
+
a
[...,
1
,
1
]
*
b
[...,
1
,
0
]
+
a
[...,
1
,
2
]
*
b
[...,
2
,
0
],
+
a
[...,
0
,
1
]
*
b
[...,
1
,
1
]
a
[...,
1
,
0
]
*
b
[...,
0
,
1
]
+
a
[...,
1
,
1
]
*
b
[...,
1
,
1
]
+
a
[...,
1
,
2
]
*
b
[...,
2
,
1
],
+
a
[...,
0
,
2
]
*
b
[...,
2
,
1
],
a
[...,
1
,
0
]
*
b
[...,
0
,
2
]
+
a
[...,
1
,
1
]
*
b
[...,
1
,
2
]
+
a
[...,
1
,
2
]
*
b
[...,
2
,
2
],
a
[...,
0
,
0
]
*
b
[...,
0
,
2
]
],
dim
=-
1
)
+
a
[...,
0
,
1
]
*
b
[...,
1
,
2
]
row_3
=
torch
.
stack
([
+
a
[...,
0
,
2
]
*
b
[...,
2
,
2
],
a
[...,
2
,
0
]
*
b
[...,
0
,
0
]
+
a
[...,
2
,
1
]
*
b
[...,
1
,
0
]
+
a
[...,
2
,
2
]
*
b
[...,
2
,
0
],
],
a
[...,
2
,
0
]
*
b
[...,
0
,
1
]
+
a
[...,
2
,
1
]
*
b
[...,
1
,
1
]
+
a
[...,
2
,
2
]
*
b
[...,
2
,
1
],
dim
=-
1
,
a
[...,
2
,
0
]
*
b
[...,
0
,
2
]
+
a
[...,
2
,
1
]
*
b
[...,
1
,
2
]
+
a
[...,
2
,
2
]
*
b
[...,
2
,
2
],
)
],
dim
=-
1
)
row_2
=
torch
.
stack
(
[
a
[...,
1
,
0
]
*
b
[...,
0
,
0
]
+
a
[...,
1
,
1
]
*
b
[...,
1
,
0
]
+
a
[...,
1
,
2
]
*
b
[...,
2
,
0
],
a
[...,
1
,
0
]
*
b
[...,
0
,
1
]
+
a
[...,
1
,
1
]
*
b
[...,
1
,
1
]
+
a
[...,
1
,
2
]
*
b
[...,
2
,
1
],
a
[...,
1
,
0
]
*
b
[...,
0
,
2
]
+
a
[...,
1
,
1
]
*
b
[...,
1
,
2
]
+
a
[...,
1
,
2
]
*
b
[...,
2
,
2
],
],
dim
=-
1
,
)
row_3
=
torch
.
stack
(
[
a
[...,
2
,
0
]
*
b
[...,
0
,
0
]
+
a
[...,
2
,
1
]
*
b
[...,
1
,
0
]
+
a
[...,
2
,
2
]
*
b
[...,
2
,
0
],
a
[...,
2
,
0
]
*
b
[...,
0
,
1
]
+
a
[...,
2
,
1
]
*
b
[...,
1
,
1
]
+
a
[...,
2
,
2
]
*
b
[...,
2
,
1
],
a
[...,
2
,
0
]
*
b
[...,
0
,
2
]
+
a
[...,
2
,
1
]
*
b
[...,
1
,
2
]
+
a
[...,
2
,
2
]
*
b
[...,
2
,
2
],
],
dim
=-
1
,
)
return
torch
.
stack
([
row_1
,
row_2
,
row_3
],
dim
=-
2
)
return
torch
.
stack
([
row_1
,
row_2
,
row_3
],
dim
=-
2
)
...
@@ -41,52 +68,56 @@ def rot_vec_mul(r, t):
...
@@ -41,52 +68,56 @@ def rot_vec_mul(r, t):
x
=
t
[...,
0
]
x
=
t
[...,
0
]
y
=
t
[...,
1
]
y
=
t
[...,
1
]
z
=
t
[...,
2
]
z
=
t
[...,
2
]
return
torch
.
stack
([
return
torch
.
stack
(
r
[...,
0
,
0
]
*
x
+
r
[...,
0
,
1
]
*
y
+
r
[...,
0
,
2
]
*
z
,
[
r
[...,
1
,
0
]
*
x
+
r
[...,
1
,
1
]
*
y
+
r
[...,
1
,
2
]
*
z
,
r
[...,
0
,
0
]
*
x
+
r
[...,
0
,
1
]
*
y
+
r
[...,
0
,
2
]
*
z
,
r
[...,
2
,
0
]
*
x
+
r
[...,
2
,
1
]
*
y
+
r
[...,
2
,
2
]
*
z
,
r
[...,
1
,
0
]
*
x
+
r
[...,
1
,
1
]
*
y
+
r
[...,
1
,
2
]
*
z
,
],
dim
=-
1
)
r
[...,
2
,
0
]
*
x
+
r
[...,
2
,
1
]
*
y
+
r
[...,
2
,
2
]
*
z
,
],
dim
=-
1
,
)
class
T
:
class
T
:
def
__init__
(
self
,
rots
,
trans
):
def
__init__
(
self
,
rots
,
trans
):
self
.
rots
=
rots
self
.
rots
=
rots
self
.
trans
=
trans
self
.
trans
=
trans
if
(
self
.
rots
is
None
and
self
.
trans
is
None
)
:
if
self
.
rots
is
None
and
self
.
trans
is
None
:
raise
ValueError
(
"Only one of rots and trans can be None"
)
raise
ValueError
(
"Only one of rots and trans can be None"
)
elif
(
self
.
rots
is
None
)
:
elif
self
.
rots
is
None
:
self
.
rots
=
T
.
identity_rot
(
self
.
rots
=
T
.
identity_rot
(
self
.
trans
.
shape
[:
-
1
],
self
.
trans
.
shape
[:
-
1
],
self
.
trans
.
dtype
,
self
.
trans
.
dtype
,
self
.
trans
.
device
,
self
.
trans
.
device
,
self
.
trans
.
requires_grad
,
self
.
trans
.
requires_grad
,
)
)
elif
(
self
.
trans
is
None
)
:
elif
self
.
trans
is
None
:
self
.
trans
=
T
.
identity_trans
(
self
.
trans
=
T
.
identity_trans
(
self
.
rots
.
shape
[:
-
2
],
self
.
rots
.
shape
[:
-
2
],
self
.
rots
.
dtype
,
self
.
rots
.
dtype
,
self
.
rots
.
device
,
self
.
rots
.
device
,
self
.
rots
.
requires_grad
self
.
rots
.
requires_grad
,
)
)
if
(
self
.
rots
.
shape
[
-
2
:]
!=
(
3
,
3
)
or
if
(
self
.
trans
.
shape
[
-
1
]
!=
3
or
self
.
rots
.
shape
[
-
2
:]
!=
(
3
,
3
)
self
.
rots
.
shape
[:
-
2
]
!=
self
.
trans
.
shape
[:
-
1
]):
or
self
.
trans
.
shape
[
-
1
]
!=
3
or
self
.
rots
.
shape
[:
-
2
]
!=
self
.
trans
.
shape
[:
-
1
]
):
raise
ValueError
(
"Incorrectly shaped input"
)
raise
ValueError
(
"Incorrectly shaped input"
)
def
__getitem__
(
self
,
index
):
def
__getitem__
(
self
,
index
):
if
(
type
(
index
)
!=
tuple
)
:
if
type
(
index
)
!=
tuple
:
index
=
(
index
,)
index
=
(
index
,)
return
T
(
return
T
(
self
.
rots
[
index
+
(
slice
(
None
),
slice
(
None
))],
self
.
rots
[
index
+
(
slice
(
None
),
slice
(
None
))],
self
.
trans
[
index
+
(
slice
(
None
),)]
self
.
trans
[
index
+
(
slice
(
None
),)]
,
)
)
def
__eq__
(
self
,
obj
):
def
__eq__
(
self
,
obj
):
return
(
return
torch
.
all
(
self
.
rots
==
obj
.
rots
)
and
torch
.
all
(
torch
.
all
(
self
.
rots
==
obj
.
rots
)
and
self
.
trans
==
obj
.
trans
torch
.
all
(
self
.
trans
==
obj
.
trans
)
)
)
def
__mul__
(
self
,
right
):
def
__mul__
(
self
,
right
):
...
@@ -135,7 +166,7 @@ class T:
...
@@ -135,7 +166,7 @@ class T:
return
T
(
rot_inv
,
-
1
*
trn_inv
)
return
T
(
rot_inv
,
-
1
*
trn_inv
)
def
unsqueeze
(
self
,
dim
):
def
unsqueeze
(
self
,
dim
):
if
(
dim
>=
len
(
self
.
shape
)
)
:
if
dim
>=
len
(
self
.
shape
):
raise
ValueError
(
"Invalid dimension"
)
raise
ValueError
(
"Invalid dimension"
)
rots
=
self
.
rots
.
unsqueeze
(
dim
if
dim
>=
0
else
dim
-
2
)
rots
=
self
.
rots
.
unsqueeze
(
dim
if
dim
>=
0
else
dim
-
2
)
trans
=
self
.
trans
.
unsqueeze
(
dim
if
dim
>=
0
else
dim
-
1
)
trans
=
self
.
trans
.
unsqueeze
(
dim
if
dim
>=
0
else
dim
-
1
)
...
@@ -155,17 +186,14 @@ class T:
...
@@ -155,17 +186,14 @@ class T:
@
staticmethod
@
staticmethod
def
identity_trans
(
shape
,
dtype
,
device
,
requires_grad
):
def
identity_trans
(
shape
,
dtype
,
device
,
requires_grad
):
trans
=
torch
.
zeros
(
trans
=
torch
.
zeros
(
(
*
shape
,
3
),
(
*
shape
,
3
),
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
dtype
=
dtype
,
)
device
=
device
,
requires_grad
=
requires_grad
)
return
trans
return
trans
@
staticmethod
@
staticmethod
def
identity
(
shape
,
dtype
,
device
,
requires_grad
=
True
):
def
identity
(
shape
,
dtype
,
device
,
requires_grad
=
True
):
return
T
(
return
T
(
T
.
identity_rot
(
shape
,
dtype
,
device
,
requires_grad
),
T
.
identity_rot
(
shape
,
dtype
,
device
,
requires_grad
),
T
.
identity_trans
(
shape
,
dtype
,
device
,
requires_grad
),
T
.
identity_trans
(
shape
,
dtype
,
device
,
requires_grad
),
)
)
...
@@ -191,7 +219,7 @@ class T:
...
@@ -191,7 +219,7 @@ class T:
p_neg_x_axis
=
torch
.
unbind
(
p_neg_x_axis
,
dim
=-
1
)
p_neg_x_axis
=
torch
.
unbind
(
p_neg_x_axis
,
dim
=-
1
)
origin
=
torch
.
unbind
(
origin
,
dim
=-
1
)
origin
=
torch
.
unbind
(
origin
,
dim
=-
1
)
p_xy_plane
=
torch
.
unbind
(
p_xy_plane
,
dim
=-
1
)
p_xy_plane
=
torch
.
unbind
(
p_xy_plane
,
dim
=-
1
)
e0
=
[
c1
-
c2
for
c1
,
c2
in
zip
(
origin
,
p_neg_x_axis
)]
e0
=
[
c1
-
c2
for
c1
,
c2
in
zip
(
origin
,
p_neg_x_axis
)]
e1
=
[
c1
-
c2
for
c1
,
c2
in
zip
(
p_xy_plane
,
origin
)]
e1
=
[
c1
-
c2
for
c1
,
c2
in
zip
(
p_xy_plane
,
origin
)]
...
@@ -209,35 +237,31 @@ class T:
...
@@ -209,35 +237,31 @@ class T:
rots
=
torch
.
stack
([
c
for
tup
in
zip
(
e0
,
e1
,
e2
)
for
c
in
tup
],
dim
=-
1
)
rots
=
torch
.
stack
([
c
for
tup
in
zip
(
e0
,
e1
,
e2
)
for
c
in
tup
],
dim
=-
1
)
rots
=
rots
.
reshape
(
rots
.
shape
[:
-
1
]
+
(
3
,
3
))
rots
=
rots
.
reshape
(
rots
.
shape
[:
-
1
]
+
(
3
,
3
))
return
T
(
rots
,
torch
.
stack
(
origin
,
dim
=-
1
))
return
T
(
rots
,
torch
.
stack
(
origin
,
dim
=-
1
))
@
staticmethod
@
staticmethod
def
concat
(
ts
,
dim
):
def
concat
(
ts
,
dim
):
rots
=
torch
.
cat
(
rots
=
torch
.
cat
([
t
.
rots
for
t
in
ts
],
dim
=
dim
if
dim
>=
0
else
dim
-
2
)
[
t
.
rots
for
t
in
ts
],
dim
=
dim
if
dim
>=
0
else
dim
-
2
)
trans
=
torch
.
cat
(
trans
=
torch
.
cat
(
[
t
.
trans
for
t
in
ts
],
[
t
.
trans
for
t
in
ts
],
dim
=
dim
if
dim
>=
0
else
dim
-
1
dim
=
dim
if
dim
>=
0
else
dim
-
1
)
)
return
T
(
rots
,
trans
)
return
T
(
rots
,
trans
)
def
map_tensor_fn
(
self
,
fn
):
def
map_tensor_fn
(
self
,
fn
):
"""
"""
Apply a function that takes a tensor as its only argument to the
Apply a function that takes a tensor as its only argument to the
rotations and translations, treating the final two/one
rotations and translations, treating the final two/one
dimension(s), respectively, as batch dimensions.
dimension(s), respectively, as batch dimensions.
E.g.: Given t, an instance of T of shape [N, M], this function can
E.g.: Given t, an instance of T of shape [N, M], this function can
be used to sum out the second dimension thereof as follows:
be used to sum out the second dimension thereof as follows:
t = t.map_tensor_fn(lambda x: torch.sum(x, dim=-1))
t = t.map_tensor_fn(lambda x: torch.sum(x, dim=-1))
The resulting object has rotations of shape [N, 3, 3] and
The resulting object has rotations of shape [N, 3, 3] and
translations of shape [N, 3]
translations of shape [N, 3]
"""
"""
rots
=
self
.
rots
.
view
(
*
self
.
rots
.
shape
[:
-
2
],
9
)
rots
=
self
.
rots
.
view
(
*
self
.
rots
.
shape
[:
-
2
],
9
)
rots
=
torch
.
stack
(
list
(
map
(
fn
,
torch
.
unbind
(
rots
,
-
1
))),
dim
=-
1
)
rots
=
torch
.
stack
(
list
(
map
(
fn
,
torch
.
unbind
(
rots
,
-
1
))),
dim
=-
1
)
...
@@ -260,7 +284,7 @@ class T:
...
@@ -260,7 +284,7 @@ class T:
c_xyz
=
c_xyz
+
translation
c_xyz
=
c_xyz
+
translation
c_x
,
c_y
,
c_z
=
[
c_xyz
[...,
i
]
for
i
in
range
(
3
)]
c_x
,
c_y
,
c_z
=
[
c_xyz
[...,
i
]
for
i
in
range
(
3
)]
norm
=
torch
.
sqrt
(
eps
+
c_x
**
2
+
c_y
**
2
)
norm
=
torch
.
sqrt
(
eps
+
c_x
**
2
+
c_y
**
2
)
sin_c1
=
-
c_y
/
norm
sin_c1
=
-
c_y
/
norm
cos_c1
=
c_x
/
norm
cos_c1
=
c_x
/
norm
zeros
=
sin_c1
.
new_zeros
(
sin_c1
.
shape
)
zeros
=
sin_c1
.
new_zeros
(
sin_c1
.
shape
)
...
@@ -273,9 +297,9 @@ class T:
...
@@ -273,9 +297,9 @@ class T:
c1_rots
[...,
1
,
1
]
=
cos_c1
c1_rots
[...,
1
,
1
]
=
cos_c1
c1_rots
[...,
2
,
2
]
=
1
c1_rots
[...,
2
,
2
]
=
1
norm
=
torch
.
sqrt
(
eps
+
c_x
**
2
+
c_y
**
2
+
c_z
**
2
)
norm
=
torch
.
sqrt
(
eps
+
c_x
**
2
+
c_y
**
2
+
c_z
**
2
)
sin_c2
=
c_z
/
norm
sin_c2
=
c_z
/
norm
cos_c2
=
torch
.
sqrt
(
c_x
**
2
+
c_y
**
2
)
/
norm
cos_c2
=
torch
.
sqrt
(
c_x
**
2
+
c_y
**
2
)
/
norm
c2_rots
=
sin_c2
.
new_zeros
((
*
sin_c2
.
shape
,
3
,
3
))
c2_rots
=
sin_c2
.
new_zeros
((
*
sin_c2
.
shape
,
3
,
3
))
c2_rots
[...,
0
,
0
]
=
cos_c2
c2_rots
[...,
0
,
0
]
=
cos_c2
...
@@ -288,14 +312,14 @@ class T:
...
@@ -288,14 +312,14 @@ class T:
n_xyz
=
rot_vec_mul
(
c_rots
,
n_xyz
)
n_xyz
=
rot_vec_mul
(
c_rots
,
n_xyz
)
_
,
n_y
,
n_z
=
[
n_xyz
[...,
i
]
for
i
in
range
(
3
)]
_
,
n_y
,
n_z
=
[
n_xyz
[...,
i
]
for
i
in
range
(
3
)]
norm
=
torch
.
sqrt
(
eps
+
n_y
**
2
+
n_z
**
2
)
norm
=
torch
.
sqrt
(
eps
+
n_y
**
2
+
n_z
**
2
)
sin_n
=
-
n_z
/
norm
sin_n
=
-
n_z
/
norm
cos_n
=
n_y
/
norm
cos_n
=
n_y
/
norm
n_rots
=
sin_c2
.
new_zeros
((
*
sin_c2
.
shape
,
3
,
3
))
n_rots
=
sin_c2
.
new_zeros
((
*
sin_c2
.
shape
,
3
,
3
))
n_rots
[...,
0
,
0
]
=
1
n_rots
[...,
0
,
0
]
=
1
n_rots
[...,
1
,
1
]
=
cos_n
n_rots
[...,
1
,
1
]
=
cos_n
n_rots
[...,
1
,
2
]
=
-
1
*
sin_n
n_rots
[...,
1
,
2
]
=
-
1
*
sin_n
n_rots
[...,
2
,
1
]
=
sin_n
n_rots
[...,
2
,
1
]
=
sin_n
n_rots
[...,
2
,
2
]
=
cos_n
n_rots
[...,
2
,
2
]
=
cos_n
...
@@ -309,10 +333,11 @@ class T:
...
@@ -309,10 +333,11 @@ class T:
def
cuda
(
self
):
def
cuda
(
self
):
return
T
(
self
.
rots
.
cuda
(),
self
.
trans
.
cuda
())
return
T
(
self
.
rots
.
cuda
(),
self
.
trans
.
cuda
())
_quat_elements
=
[
'a'
,
'b'
,
'c'
,
'd'
]
_quat_elements
=
[
"a"
,
"b"
,
"c"
,
"d"
]
_qtr_keys
=
[
l1
+
l2
for
l1
in
_quat_elements
for
l2
in
_quat_elements
]
_qtr_keys
=
[
l1
+
l2
for
l1
in
_quat_elements
for
l2
in
_quat_elements
]
_qtr_ind_dict
=
{
key
:
ind
for
ind
,
key
in
enumerate
(
_qtr_keys
)}
_qtr_ind_dict
=
{
key
:
ind
for
ind
,
key
in
enumerate
(
_qtr_keys
)}
def
_to_mat
(
pairs
):
def
_to_mat
(
pairs
):
mat
=
torch
.
zeros
((
4
,
4
))
mat
=
torch
.
zeros
((
4
,
4
))
...
@@ -323,20 +348,20 @@ def _to_mat(pairs):
...
@@ -323,20 +348,20 @@ def _to_mat(pairs):
return
mat
return
mat
_qtr_mat
=
np
.
zeros
((
4
,
4
,
3
,
3
))
_qtr_mat
=
np
.
zeros
((
4
,
4
,
3
,
3
))
_qtr_mat
[...,
0
,
0
]
=
_to_mat
([(
'aa'
,
1
),
(
'bb'
,
1
),
(
'cc'
,
-
1
),
(
'dd'
,
-
1
)])
_qtr_mat
[...,
0
,
0
]
=
_to_mat
([(
"aa"
,
1
),
(
"bb"
,
1
),
(
"cc"
,
-
1
),
(
"dd"
,
-
1
)])
_qtr_mat
[...,
0
,
1
]
=
_to_mat
([(
'bc'
,
2
),
(
'ad'
,
-
2
)])
_qtr_mat
[...,
0
,
1
]
=
_to_mat
([(
"bc"
,
2
),
(
"ad"
,
-
2
)])
_qtr_mat
[...,
0
,
2
]
=
_to_mat
([(
'bd'
,
2
),
(
'ac'
,
2
)])
_qtr_mat
[...,
0
,
2
]
=
_to_mat
([(
"bd"
,
2
),
(
"ac"
,
2
)])
_qtr_mat
[...,
1
,
0
]
=
_to_mat
([(
'bc'
,
2
),
(
'ad'
,
2
)])
_qtr_mat
[...,
1
,
0
]
=
_to_mat
([(
"bc"
,
2
),
(
"ad"
,
2
)])
_qtr_mat
[...,
1
,
1
]
=
_to_mat
([(
'aa'
,
1
),
(
'bb'
,
-
1
),
(
'cc'
,
1
),
(
'dd'
,
-
1
)])
_qtr_mat
[...,
1
,
1
]
=
_to_mat
([(
"aa"
,
1
),
(
"bb"
,
-
1
),
(
"cc"
,
1
),
(
"dd"
,
-
1
)])
_qtr_mat
[...,
1
,
2
]
=
_to_mat
([(
'cd'
,
2
),
(
'ab'
,
-
2
)])
_qtr_mat
[...,
1
,
2
]
=
_to_mat
([(
"cd"
,
2
),
(
"ab"
,
-
2
)])
_qtr_mat
[...,
2
,
0
]
=
_to_mat
([(
'bd'
,
2
),
(
'ac'
,
-
2
)])
_qtr_mat
[...,
2
,
0
]
=
_to_mat
([(
"bd"
,
2
),
(
"ac"
,
-
2
)])
_qtr_mat
[...,
2
,
1
]
=
_to_mat
([(
'cd'
,
2
),
(
'ab'
,
2
)])
_qtr_mat
[...,
2
,
1
]
=
_to_mat
([(
"cd"
,
2
),
(
"ab"
,
2
)])
_qtr_mat
[...,
2
,
2
]
=
_to_mat
([(
'aa'
,
1
),
(
'bb'
,
-
1
),
(
'cc'
,
-
1
),
(
'dd'
,
1
)])
_qtr_mat
[...,
2
,
2
]
=
_to_mat
([(
"aa"
,
1
),
(
"bb"
,
-
1
),
(
"cc"
,
-
1
),
(
"dd"
,
1
)])
def
quat_to_rot
(
quat
# [*, 4]
def
quat_to_rot
(
quat
):
# [*, 4]
):
# [*, 4, 4]
# [*, 4, 4]
quat
=
quat
[...,
None
]
*
quat
[...,
None
,
:]
quat
=
quat
[...,
None
]
*
quat
[...,
None
,
:]
...
@@ -350,6 +375,7 @@ def quat_to_rot(
...
@@ -350,6 +375,7 @@ def quat_to_rot(
# [*, 3, 3]
# [*, 3, 3]
return
torch
.
sum
(
quat
,
dim
=
(
-
3
,
-
4
))
return
torch
.
sum
(
quat
,
dim
=
(
-
3
,
-
4
))
def
affine_vector_to_4x4
(
vector
):
def
affine_vector_to_4x4
(
vector
):
quats
=
vector
[...,
:
4
]
quats
=
vector
[...,
:
4
]
trans
=
vector
[...,
4
:]
trans
=
vector
[...,
4
:]
...
...
openfold/utils/deepspeed.py
View file @
07e64267
...
@@ -20,31 +20,32 @@ from typing import Any, Tuple, List, Callable
...
@@ -20,31 +20,32 @@ from typing import Any, Tuple, List, Callable
BLOCK_ARG
=
Any
BLOCK_ARG
=
Any
BLOCK_ARGS
=
List
[
BLOCK_ARG
]
BLOCK_ARGS
=
List
[
BLOCK_ARG
]
def
checkpoint_blocks
(
def
checkpoint_blocks
(
blocks
:
List
[
Callable
],
blocks
:
List
[
Callable
],
args
:
BLOCK_ARGS
,
args
:
BLOCK_ARGS
,
blocks_per_ckpt
:
int
,
blocks_per_ckpt
:
int
,
)
->
BLOCK_ARGS
:
)
->
BLOCK_ARGS
:
"""
"""
Chunk a list of blocks and run each chunk with activation
Chunk a list of blocks and run each chunk with activation
checkpointing. We define a "block" as a callable whose only inputs are
checkpointing. We define a "block" as a callable whose only inputs are
the outputs of the previous block.
the outputs of the previous block.
This function assumes that deepspeed has already been initialized.
This function assumes that deepspeed has already been initialized.
Implements Subsection 1.11.8
Implements Subsection 1.11.8
Args:
Args:
blocks:
blocks:
List of blocks
List of blocks
args:
args:
Tuple of arguments for the first block.
Tuple of arguments for the first block.
blocks_per_ckpt:
blocks_per_ckpt:
Size of each chunk. A higher value corresponds to higher memory
Size of each chunk. A higher value corresponds to higher memory
consumption but fewer checkpoints. If None, no checkpointing is
consumption but fewer checkpoints. If None, no checkpointing is
performed.
performed.
Returns:
Returns:
The output of the final block
The output of the final block
"""
"""
def
wrap
(
a
):
def
wrap
(
a
):
...
@@ -58,19 +59,20 @@ def checkpoint_blocks(
...
@@ -58,19 +59,20 @@ def checkpoint_blocks(
def
chunker
(
s
,
e
):
def
chunker
(
s
,
e
):
def
exec_sliced
(
*
a
):
def
exec_sliced
(
*
a
):
return
exec
(
blocks
[
s
:
e
],
a
)
return
exec
(
blocks
[
s
:
e
],
a
)
return
exec_sliced
return
exec_sliced
# Avoids mishaps when the blocks take just one argument
# Avoids mishaps when the blocks take just one argument
args
=
wrap
(
args
)
args
=
wrap
(
args
)
if
(
blocks_per_ckpt
is
None
)
:
if
blocks_per_ckpt
is
None
:
return
exec
(
blocks
,
args
)
return
exec
(
blocks
,
args
)
elif
(
blocks_per_ckpt
<
1
or
blocks_per_ckpt
>
len
(
blocks
)
)
:
elif
blocks_per_ckpt
<
1
or
blocks_per_ckpt
>
len
(
blocks
):
raise
ValueError
(
"blocks_per_ckpt must be between 1 and len(blocks)"
)
raise
ValueError
(
"blocks_per_ckpt must be between 1 and len(blocks)"
)
for
s
in
range
(
0
,
len
(
blocks
),
blocks_per_ckpt
):
for
s
in
range
(
0
,
len
(
blocks
),
blocks_per_ckpt
):
e
=
s
+
blocks_per_ckpt
e
=
s
+
blocks_per_ckpt
#args = checkpoint(chunker(s, e), *args)
#
args = checkpoint(chunker(s, e), *args)
args
=
deepspeed
.
checkpointing
.
checkpoint
(
chunker
(
s
,
e
),
*
args
)
args
=
deepspeed
.
checkpointing
.
checkpoint
(
chunker
(
s
,
e
),
*
args
)
args
=
wrap
(
args
)
args
=
wrap
(
args
)
...
...
openfold/utils/exponential_moving_average.py
View file @
07e64267
...
@@ -3,26 +3,28 @@ import copy
...
@@ -3,26 +3,28 @@ import copy
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
class
ExponentialMovingAverage
:
class
ExponentialMovingAverage
:
"""
"""
Maintains moving averages of parameters with exponential decay
Maintains moving averages of parameters with exponential decay
At each step, the stored copy `copy` of each parameter `param` is
updated as follows:
At each step, the stored copy `copy` of each parameter `param` is
`copy = decay * copy + (1 - decay) * param`
updated as follows:
`copy = decay * copy + (1 - decay) * param`
where `decay` is an attribute of the ExponentialMovingAverage object.
where `decay` is an attribute of the ExponentialMovingAverage object.
"""
"""
def
__init__
(
self
,
model
:
nn
.
Module
,
decay
:
float
):
def
__init__
(
self
,
model
:
nn
.
Module
,
decay
:
float
):
"""
"""
Args:
Args:
model:
model:
A torch.nn.Module whose parameters are to be tracked
A torch.nn.Module whose parameters are to be tracked
decay:
decay:
A value (usually close to 1.) by which updates are
A value (usually close to 1.) by which updates are
weighted as part of the above formula
weighted as part of the above formula
"""
"""
super
(
ExponentialMovingAverage
,
self
).
__init__
()
super
(
ExponentialMovingAverage
,
self
).
__init__
()
self
.
params
=
copy
.
deepcopy
(
model
.
state_dict
())
self
.
params
=
copy
.
deepcopy
(
model
.
state_dict
())
...
@@ -32,27 +34,29 @@ class ExponentialMovingAverage:
...
@@ -32,27 +34,29 @@ class ExponentialMovingAverage:
with
torch
.
no_grad
():
with
torch
.
no_grad
():
for
k
,
v
in
update
.
items
():
for
k
,
v
in
update
.
items
():
stored
=
state_dict
[
k
]
stored
=
state_dict
[
k
]
if
(
not
isinstance
(
v
,
torch
.
Tensor
)
)
:
if
not
isinstance
(
v
,
torch
.
Tensor
):
self
.
_update_state_dict_
(
v
,
stored
)
self
.
_update_state_dict_
(
v
,
stored
)
else
:
else
:
diff
=
stored
-
v
diff
=
stored
-
v
diff
*=
(
1
-
self
.
decay
)
diff
*=
1
-
self
.
decay
stored
-=
diff
stored
-=
diff
def
update
(
self
,
model
:
torch
.
nn
.
Module
)
->
None
:
def
update
(
self
,
model
:
torch
.
nn
.
Module
)
->
None
:
"""
"""
Updates the stored parameters using the state dict of the provided
Updates the stored parameters using the state dict of the provided
module. The module should have the same structure as that used to
module. The module should have the same structure as that used to
initialize the ExponentialMovingAverage object.
initialize the ExponentialMovingAverage object.
"""
"""
self
.
_update_state_dict_
(
model
.
state_dict
(),
self
.
params
)
self
.
_update_state_dict_
(
model
.
state_dict
(),
self
.
params
)
def
load_state_dict
(
self
,
state_dict
:
OrderedDict
)
->
None
:
def
load_state_dict
(
self
,
state_dict
:
OrderedDict
)
->
None
:
self
.
params
=
state_dict
[
"params"
]
self
.
params
=
state_dict
[
"params"
]
self
.
decay
=
state_dict
[
"decay"
]
self
.
decay
=
state_dict
[
"decay"
]
def
state_dict
(
self
)
->
OrderedDict
:
def
state_dict
(
self
)
->
OrderedDict
:
return
OrderedDict
({
return
OrderedDict
(
"params"
:
self
.
params
,
{
"decay"
:
self
.
decay
,
"params"
:
self
.
params
,
})
"decay"
:
self
.
decay
,
}
)
openfold/utils/feats.py
View file @
07e64267
...
@@ -18,10 +18,10 @@ import torch
...
@@ -18,10 +18,10 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
typing
import
Dict
from
typing
import
Dict
import
openfold.np.residue_constants
as
rc
import
openfold.np.residue_constants
as
rc
from
openfold.utils.affine_utils
import
T
from
openfold.utils.affine_utils
import
T
from
openfold.utils.tensor_utils
import
(
from
openfold.utils.tensor_utils
import
(
batched_gather
,
batched_gather
,
one_hot
,
one_hot
,
tree_map
,
tree_map
,
tensor_tree_map
,
tensor_tree_map
,
...
@@ -29,16 +29,16 @@ from openfold.utils.tensor_utils import (
...
@@ -29,16 +29,16 @@ from openfold.utils.tensor_utils import (
def
pseudo_beta_fn
(
aatype
,
all_atom_positions
,
all_atom_masks
):
def
pseudo_beta_fn
(
aatype
,
all_atom_positions
,
all_atom_masks
):
is_gly
=
(
aatype
==
rc
.
restype_order
[
'G'
])
is_gly
=
aatype
==
rc
.
restype_order
[
"G"
]
ca_idx
=
rc
.
atom_order
[
'
CA
'
]
ca_idx
=
rc
.
atom_order
[
"
CA
"
]
cb_idx
=
rc
.
atom_order
[
'
CB
'
]
cb_idx
=
rc
.
atom_order
[
"
CB
"
]
pseudo_beta
=
torch
.
where
(
pseudo_beta
=
torch
.
where
(
is_gly
[...,
None
].
expand
(
*
((
-
1
,)
*
len
(
is_gly
.
shape
)),
3
),
is_gly
[...,
None
].
expand
(
*
((
-
1
,)
*
len
(
is_gly
.
shape
)),
3
),
all_atom_positions
[...,
ca_idx
,
:],
all_atom_positions
[...,
ca_idx
,
:],
all_atom_positions
[...,
cb_idx
,
:]
all_atom_positions
[...,
cb_idx
,
:]
,
)
)
if
(
all_atom_masks
is
not
None
)
:
if
all_atom_masks
is
not
None
:
pseudo_beta_mask
=
torch
.
where
(
pseudo_beta_mask
=
torch
.
where
(
is_gly
,
is_gly
,
all_atom_masks
[...,
ca_idx
],
all_atom_masks
[...,
ca_idx
],
...
@@ -65,9 +65,9 @@ def atom14_to_atom37(atom14, batch):
...
@@ -65,9 +65,9 @@ def atom14_to_atom37(atom14, batch):
def
build_template_angle_feat
(
template_feats
):
def
build_template_angle_feat
(
template_feats
):
template_aatype
=
template_feats
[
"template_aatype"
]
template_aatype
=
template_feats
[
"template_aatype"
]
torsion_angles_sin_cos
=
template_feats
[
"template_torsion_angles_sin_cos"
]
torsion_angles_sin_cos
=
template_feats
[
"template_torsion_angles_sin_cos"
]
alt_torsion_angles_sin_cos
=
(
alt_torsion_angles_sin_cos
=
template_feats
[
template_feats
[
"template_alt_torsion_angles_sin_cos"
]
"template_alt_torsion_angles_sin_cos"
)
]
torsion_angles_mask
=
template_feats
[
"template_torsion_angles_mask"
]
torsion_angles_mask
=
template_feats
[
"template_torsion_angles_mask"
]
template_angle_feat
=
torch
.
cat
(
template_angle_feat
=
torch
.
cat
(
[
[
...
@@ -79,21 +79,24 @@ def build_template_angle_feat(template_feats):
...
@@ -79,21 +79,24 @@ def build_template_angle_feat(template_feats):
*
alt_torsion_angles_sin_cos
.
shape
[:
-
2
],
14
*
alt_torsion_angles_sin_cos
.
shape
[:
-
2
],
14
),
),
torsion_angles_mask
,
torsion_angles_mask
,
],
],
dim
=-
1
,
dim
=-
1
,
)
)
return
template_angle_feat
return
template_angle_feat
def
build_template_pair_feat
(
batch
,
min_bin
,
max_bin
,
no_bins
,
eps
=
1e-20
,
inf
=
1e8
):
def
build_template_pair_feat
(
batch
,
min_bin
,
max_bin
,
no_bins
,
eps
=
1e-20
,
inf
=
1e8
):
template_mask
=
batch
[
"template_pseudo_beta_mask"
]
template_mask
=
batch
[
"template_pseudo_beta_mask"
]
template_mask_2d
=
template_mask
[...,
None
]
*
template_mask
[...,
None
,
:]
template_mask_2d
=
template_mask
[...,
None
]
*
template_mask
[...,
None
,
:]
# Compute distogram (this seems to differ slightly from Alg. 5)
# Compute distogram (this seems to differ slightly from Alg. 5)
tpb
=
batch
[
"template_pseudo_beta"
]
tpb
=
batch
[
"template_pseudo_beta"
]
dgram
=
torch
.
sum
(
dgram
=
torch
.
sum
(
(
tpb
[...,
None
,
:]
-
tpb
[...,
None
,
:,
:])
**
2
,
dim
=-
1
,
keepdim
=
True
)
(
tpb
[...,
None
,
:]
-
tpb
[...,
None
,
:,
:])
**
2
,
dim
=-
1
,
keepdim
=
True
)
lower
=
torch
.
linspace
(
min_bin
,
max_bin
,
no_bins
,
device
=
tpb
.
device
)
**
2
lower
=
torch
.
linspace
(
min_bin
,
max_bin
,
no_bins
,
device
=
tpb
.
device
)
**
2
upper
=
torch
.
cat
([
lower
[:
-
1
],
lower
.
new_tensor
([
inf
])],
dim
=-
1
)
upper
=
torch
.
cat
([
lower
[:
-
1
],
lower
.
new_tensor
([
inf
])],
dim
=-
1
)
dgram
=
((
dgram
>
lower
)
*
(
dgram
<
upper
)).
type
(
dgram
.
dtype
)
dgram
=
((
dgram
>
lower
)
*
(
dgram
<
upper
)).
type
(
dgram
.
dtype
)
...
@@ -101,7 +104,8 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e
...
@@ -101,7 +104,8 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e
to_concat
=
[
dgram
,
template_mask_2d
[...,
None
]]
to_concat
=
[
dgram
,
template_mask_2d
[...,
None
]]
aatype_one_hot
=
nn
.
functional
.
one_hot
(
aatype_one_hot
=
nn
.
functional
.
one_hot
(
batch
[
"template_aatype"
],
rc
.
restype_num
+
2
,
batch
[
"template_aatype"
],
rc
.
restype_num
+
2
,
)
)
n_res
=
batch
[
"template_aatype"
].
shape
[
-
1
]
n_res
=
batch
[
"template_aatype"
].
shape
[
-
1
]
...
@@ -116,7 +120,7 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e
...
@@ -116,7 +120,7 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e
)
)
)
)
n
,
ca
,
c
=
[
rc
.
atom_order
[
a
]
for
a
in
[
'N'
,
'
CA
'
,
'C'
]]
n
,
ca
,
c
=
[
rc
.
atom_order
[
a
]
for
a
in
[
"N"
,
"
CA
"
,
"C"
]]
# TODO: Consider running this in double precision
# TODO: Consider running this in double precision
affines
=
T
.
make_transform_from_reference
(
affines
=
T
.
make_transform_from_reference
(
n_xyz
=
batch
[
"template_all_atom_positions"
][...,
n
,
:],
n_xyz
=
batch
[
"template_all_atom_positions"
][...,
n
,
:],
...
@@ -127,10 +131,8 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e
...
@@ -127,10 +131,8 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e
points
=
affines
.
get_trans
()[...,
None
,
:,
:]
points
=
affines
.
get_trans
()[...,
None
,
:,
:]
affine_vec
=
affines
[...,
None
].
invert_apply
(
points
)
affine_vec
=
affines
[...,
None
].
invert_apply
(
points
)
inv_distance_scalar
=
torch
.
rsqrt
(
inv_distance_scalar
=
torch
.
rsqrt
(
eps
+
torch
.
sum
(
affine_vec
**
2
,
dim
=-
1
))
eps
+
torch
.
sum
(
affine_vec
**
2
,
dim
=-
1
)
)
t_aa_masks
=
batch
[
"template_all_atom_mask"
]
t_aa_masks
=
batch
[
"template_all_atom_mask"
]
template_mask
=
(
template_mask
=
(
...
@@ -139,10 +141,10 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e
...
@@ -139,10 +141,10 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e
template_mask_2d
=
template_mask
[...,
None
]
*
template_mask
[...,
None
,
:]
template_mask_2d
=
template_mask
[...,
None
]
*
template_mask
[...,
None
,
:]
inv_distance_scalar
=
inv_distance_scalar
*
template_mask_2d
inv_distance_scalar
=
inv_distance_scalar
*
template_mask_2d
unit_vector
=
(
affine_vec
*
inv_distance_scalar
[...,
None
]
)
unit_vector
=
affine_vec
*
inv_distance_scalar
[...,
None
]
to_concat
.
extend
(
torch
.
unbind
(
unit_vector
[...,
None
,
:],
dim
=-
1
))
to_concat
.
extend
(
torch
.
unbind
(
unit_vector
[...,
None
,
:],
dim
=-
1
))
to_concat
.
append
(
template_mask_2d
[...,
None
])
to_concat
.
append
(
template_mask_2d
[...,
None
])
act
=
torch
.
cat
(
to_concat
,
dim
=-
1
)
act
=
torch
.
cat
(
to_concat
,
dim
=-
1
)
act
=
act
*
template_mask_2d
[...,
None
]
act
=
act
*
template_mask_2d
[...,
None
]
...
@@ -161,55 +163,62 @@ def build_extra_msa_feat(batch):
...
@@ -161,55 +163,62 @@ def build_extra_msa_feat(batch):
# adapted from model/tf/data_transforms.py
# adapted from model/tf/data_transforms.py
def
build_msa_feat
(
batch
):
def
build_msa_feat
(
batch
):
"""Create and concatenate MSA features."""
"""Create and concatenate MSA features."""
# Whether there is a domain break. Always zero for chains, but keeping
# Whether there is a domain break. Always zero for chains, but keeping
# for compatibility with domain datasets.
# for compatibility with domain datasets.
has_break
=
batch
[
"between_segment_residues"
]
has_break
=
batch
[
"between_segment_residues"
]
aatype_1hot
=
nn
.
functional
.
one_hot
(
batch
[
'aatype'
],
num_classes
=
21
)
aatype_1hot
=
nn
.
functional
.
one_hot
(
batch
[
"aatype"
],
num_classes
=
21
)
target_feat
=
[
target_feat
=
[
has_break
.
unsqueeze
(
-
1
),
has_break
.
unsqueeze
(
-
1
),
aatype_1hot
,
# Everyone gets the original sequence.
aatype_1hot
,
# Everyone gets the original sequence.
]
]
msa_1hot
=
nn
.
functional
.
one_hot
(
batch
[
'msa'
],
num_classes
=
23
)
msa_1hot
=
nn
.
functional
.
one_hot
(
batch
[
"msa"
],
num_classes
=
23
)
has_deletion
=
batch
[
"deletion_matrix"
]
has_deletion
=
batch
[
"deletion_matrix"
]
deletion_value
=
torch
.
atan
(
batch
[
'deletion_matrix'
]
/
3.
)
*
(
2.
/
math
.
pi
)
deletion_value
=
torch
.
atan
(
batch
[
"deletion_matrix"
]
/
3.0
)
*
(
2.0
/
math
.
pi
msa_feat
=
[
)
msa_1hot
,
has_deletion
.
unsqueeze
(
-
1
),
msa_feat
=
[
deletion_value
.
unsqueeze
(
-
1
),
msa_1hot
,
]
has_deletion
.
unsqueeze
(
-
1
),
deletion_value
.
unsqueeze
(
-
1
),
if
'cluster_profile'
in
batch
:
]
deletion_mean_value
=
(
tf
.
atan
(
batch
[
'cluster_deletion_mean'
]
/
3.
)
*
(
2.
/
np
.
pi
))
if
"cluster_profile"
in
batch
:
msa_feat
.
extend
([
deletion_mean_value
=
tf
.
atan
(
batch
[
"cluster_deletion_mean"
]
/
3.0
)
*
(
batch
[
'cluster_profile'
],
2.0
/
np
.
pi
tf
.
expand_dims
(
deletion_mean_value
,
axis
=-
1
),
)
])
msa_feat
.
extend
(
[
if
'extra_deletion_matrix'
in
protein
:
batch
[
"cluster_profile"
],
batch
[
'extra_has_deletion'
]
=
tf
.
clip_by_value
(
tf
.
expand_dims
(
deletion_mean_value
,
axis
=-
1
),
batch
[
'extra_deletion_matrix'
],
0.
,
1.
)
]
batch
[
'extra_deletion_value'
]
=
tf
.
atan
(
)
batch
[
'extra_deletion_matrix'
]
/
3.
)
*
(
2.
/
np
.
pi
)
if
"extra_deletion_matrix"
in
protein
:
batch
[
'msa_feat'
]
=
torch
.
cat
(
msa_feat
,
dim
=-
1
)
batch
[
"extra_has_deletion"
]
=
tf
.
clip_by_value
(
batch
[
'target_feat'
]
=
torch
.
cat
(
target_feat
,
dim
=-
1
)
batch
[
"extra_deletion_matrix"
],
0.0
,
1.0
return
batch
)
batch
[
"extra_deletion_value"
]
=
tf
.
atan
(
batch
[
"extra_deletion_matrix"
]
/
3.0
)
*
(
2.0
/
np
.
pi
)
batch
[
"msa_feat"
]
=
torch
.
cat
(
msa_feat
,
dim
=-
1
)
batch
[
"target_feat"
]
=
torch
.
cat
(
target_feat
,
dim
=-
1
)
return
batch
def
torsion_angles_to_frames
(
def
torsion_angles_to_frames
(
t
:
T
,
t
:
T
,
alpha
:
torch
.
Tensor
,
alpha
:
torch
.
Tensor
,
aatype
:
torch
.
Tensor
,
aatype
:
torch
.
Tensor
,
rrgdf
:
torch
.
Tensor
,
rrgdf
:
torch
.
Tensor
,
):
):
# [*, N, 8, 4, 4]
# [*, N, 8, 4, 4]
default_4x4
=
rrgdf
[
aatype
,
...]
default_4x4
=
rrgdf
[
aatype
,
...]
# [*, N, 8] transformations, i.e.
# [*, N, 8] transformations, i.e.
# One [*, N, 8, 3, 3] rotation matrix and
# One [*, N, 8, 3, 3] rotation matrix and
# One [*, N, 8, 3] translation matrix
# One [*, N, 8, 3] translation matrix
...
@@ -217,12 +226,9 @@ def torsion_angles_to_frames(
...
@@ -217,12 +226,9 @@ def torsion_angles_to_frames(
bb_rot
=
alpha
.
new_zeros
((
*
((
1
,)
*
len
(
alpha
.
shape
[:
-
1
])),
2
))
bb_rot
=
alpha
.
new_zeros
((
*
((
1
,)
*
len
(
alpha
.
shape
[:
-
1
])),
2
))
bb_rot
[...,
1
]
=
1
bb_rot
[...,
1
]
=
1
# [*, N, 8, 2]
# [*, N, 8, 2]
alpha
=
torch
.
cat
(
alpha
=
torch
.
cat
([
bb_rot
.
expand
(
*
alpha
.
shape
[:
-
2
],
-
1
,
-
1
),
alpha
],
dim
=-
2
)
[
bb_rot
.
expand
(
*
alpha
.
shape
[:
-
2
],
-
1
,
-
1
),
alpha
],
dim
=-
2
)
# [*, N, 8, 3, 3]
# [*, N, 8, 3, 3]
# Produces rotation matrices of the form:
# Produces rotation matrices of the form:
...
@@ -233,7 +239,7 @@ def torsion_angles_to_frames(
...
@@ -233,7 +239,7 @@ def torsion_angles_to_frames(
# ]
# ]
# This follows the original code rather than the supplement, which uses
# This follows the original code rather than the supplement, which uses
# different indices.
# different indices.
all_rots
=
alpha
.
new_zeros
(
default_t
.
rots
.
shape
)
all_rots
=
alpha
.
new_zeros
(
default_t
.
rots
.
shape
)
all_rots
[...,
0
,
0
]
=
1
all_rots
[...,
0
,
0
]
=
1
all_rots
[...,
1
,
1
]
=
alpha
[...,
1
]
all_rots
[...,
1
,
1
]
=
alpha
[...,
1
]
...
@@ -253,12 +259,14 @@ def torsion_angles_to_frames(
...
@@ -253,12 +259,14 @@ def torsion_angles_to_frames(
chi3_frame_to_bb
=
chi2_frame_to_bb
.
compose
(
chi3_frame_to_frame
)
chi3_frame_to_bb
=
chi2_frame_to_bb
.
compose
(
chi3_frame_to_frame
)
chi4_frame_to_bb
=
chi3_frame_to_bb
.
compose
(
chi4_frame_to_frame
)
chi4_frame_to_bb
=
chi3_frame_to_bb
.
compose
(
chi4_frame_to_frame
)
all_frames_to_bb
=
T
.
concat
([
all_frames_to_bb
=
T
.
concat
(
[
all_frames
[...,
:
5
],
all_frames
[...,
:
5
],
chi2_frame_to_bb
.
unsqueeze
(
-
1
),
chi2_frame_to_bb
.
unsqueeze
(
-
1
),
chi3_frame_to_bb
.
unsqueeze
(
-
1
),
chi3_frame_to_bb
.
unsqueeze
(
-
1
),
chi4_frame_to_bb
.
unsqueeze
(
-
1
),
chi4_frame_to_bb
.
unsqueeze
(
-
1
),
],
dim
=-
1
,
],
dim
=-
1
,
)
)
all_frames_to_global
=
t
[...,
None
].
compose
(
all_frames_to_bb
)
all_frames_to_global
=
t
[...,
None
].
compose
(
all_frames_to_bb
)
...
@@ -274,20 +282,21 @@ def frames_and_literature_positions_to_atom14_pos(
...
@@ -274,20 +282,21 @@ def frames_and_literature_positions_to_atom14_pos(
atom_mask
,
atom_mask
,
lit_positions
,
lit_positions
,
):
):
# [*, N, 14, 4, 4]
# [*, N, 14, 4, 4]
default_4x4
=
default_frames
[
aatype
,
...]
default_4x4
=
default_frames
[
aatype
,
...]
# [*, N, 14]
# [*, N, 14]
group_mask
=
group_idx
[
aatype
,
...]
group_mask
=
group_idx
[
aatype
,
...]
# [*, N, 14, 8]
# [*, N, 14, 8]
group_mask
=
nn
.
functional
.
one_hot
(
group_mask
=
nn
.
functional
.
one_hot
(
group_mask
,
num_classes
=
default_frames
.
shape
[
-
3
],
group_mask
,
num_classes
=
default_frames
.
shape
[
-
3
],
)
)
# [*, N, 14, 8]
# [*, N, 14, 8]
t_atoms_to_global
=
t
[...,
None
,
:]
*
group_mask
t_atoms_to_global
=
t
[...,
None
,
:]
*
group_mask
# [*, N, 14]
# [*, N, 14]
t_atoms_to_global
=
t_atoms_to_global
.
map_tensor_fn
(
t_atoms_to_global
=
t_atoms_to_global
.
map_tensor_fn
(
lambda
x
:
torch
.
sum
(
x
,
dim
=-
1
)
lambda
x
:
torch
.
sum
(
x
,
dim
=-
1
)
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment