Commit 07e64267 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Standardize code style

parent de07730f
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
......@@ -23,17 +23,18 @@ from openfold.utils.tensor_utils import chunk_layer
class OuterProductMean(nn.Module):
"""
Implements Algorithm 10.
Implements Algorithm 10.
"""
def __init__(self, c_m, c_z, c_hidden, chunk_size=4, eps=1e-3):
"""
Args:
c_m:
MSA embedding channel dimension
c_z:
Pair embedding channel dimension
c_hidden:
Hidden channel dimension
Args:
c_m:
MSA embedding channel dimension
c_z:
Pair embedding channel dimension
c_hidden:
Hidden channel dimension
"""
super(OuterProductMean, self).__init__()
......@@ -45,12 +46,12 @@ class OuterProductMean(nn.Module):
self.layer_norm = nn.LayerNorm(c_m)
self.linear_1 = Linear(c_m, c_hidden)
self.linear_2 = Linear(c_m, c_hidden)
self.linear_out = Linear(c_hidden**2, c_z, init="final")
self.linear_out = Linear(c_hidden ** 2, c_z, init="final")
def _opm(self, a, b):
# [*, N_res, N_res, C, C]
outer = torch.einsum("...bac,...dae->...bdce", a, b)
# [*, N_res, N_res, C * C]
outer = outer.reshape(*outer.shape[:-2], -1)
......@@ -61,20 +62,20 @@ class OuterProductMean(nn.Module):
def forward(self, m, mask=None):
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
mask:
[*, N_seq, N_res] MSA mask
Returns:
[*, N_res, N_res, C_z] pair embedding update
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
mask:
[*, N_seq, N_res] MSA mask
Returns:
[*, N_res, N_res, C_z] pair embedding update
"""
if(mask is None):
if mask is None:
mask = m.new_ones(m.shape[:-1])
# [*, N_seq, N_res, C_m]
m = self.layer_norm(m)
# [*, N_seq, N_res, C]
mask = mask.unsqueeze(-1)
a = self.linear_1(m) * mask
......@@ -83,7 +84,7 @@ class OuterProductMean(nn.Module):
a = a.transpose(-2, -3)
b = b.transpose(-2, -3)
if(self.chunk_size is not None):
if self.chunk_size is not None:
# Since the "batch dim" in this case is not a true batch dimension
# (in that the shape of the output depends on it), we need to
# iterate over it ourselves
......
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
......@@ -22,16 +22,17 @@ from openfold.utils.tensor_utils import chunk_layer
class PairTransition(nn.Module):
"""
Implements Algorithm 15.
Implements Algorithm 15.
"""
def __init__(self, c_z, n, chunk_size=4):
"""
Args:
c_z:
Pair transition channel dimension
n:
Factor by which c_z is multiplied to obtain hidden channel
dimension
Args:
c_z:
Pair transition channel dimension
n:
Factor by which c_z is multiplied to obtain hidden channel
dimension
"""
super(PairTransition, self).__init__()
......@@ -56,14 +57,14 @@ class PairTransition(nn.Module):
def forward(self, z, mask=None):
"""
Args:
z:
[*, N_res, N_res, C_z] pair embedding
Returns:
[*, N_res, N_res, C_z] pair embedding update
Args:
z:
[*, N_res, N_res, C_z] pair embedding
Returns:
[*, N_res, N_res, C_z] pair embedding update
"""
# DISCREPANCY: DeepMind forgets to apply the mask in this module.
if(mask is None):
if mask is None:
mask = z.new_ones(z.shape[:-1])
# [*, N_res, N_res, 1]
......@@ -73,12 +74,12 @@ class PairTransition(nn.Module):
z = self.layer_norm(z)
inp = {"z": z, "mask": mask}
if(self.chunk_size is not None):
if self.chunk_size is not None:
z = chunk_layer(
self._transition,
inp,
chunk_size=self.chunk_size,
no_batch_dims=len(z.shape[:-2]),
no_batch_dims=len(z.shape[:-2]),
)
else:
z = self._transition(**inp)
......
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
......@@ -14,7 +14,7 @@
# limitations under the License.
import math
from typing import Optional, Callable, List
from typing import Optional, Callable, List
import numpy as np
import torch
......@@ -22,7 +22,7 @@ import torch.nn as nn
from scipy.stats import truncnorm
from openfold.utils.tensor_utils import (
permute_final_dims,
permute_final_dims,
flatten_final_dims,
)
......@@ -33,6 +33,7 @@ def _prod(nums):
out = out * n
return out
def _calculate_fan(shape, fan="fan_in"):
i = shape[0]
o = shape[1]
......@@ -40,20 +41,20 @@ def _calculate_fan(shape, fan="fan_in"):
fan_in = prod * i
fan_out = prod * o
if(fan == "fan_in"):
if fan == "fan_in":
f = fan_in
elif(fan == "fan_out"):
elif fan == "fan_out":
f = fan_out
elif(fan == "fan_avg"):
elif fan == "fan_avg":
f = (fan_in + fan_out) / 2
else:
raise ValueError("Invalid fan option")
return f
def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
shape = weights.shape
shape = weights.shape
f = _calculate_fan(shape, fan)
scale = scale / max(1, f)
a = -2
......@@ -80,17 +81,17 @@ def glorot_uniform_init_(weights):
def final_init_(weights):
with torch.no_grad():
weights.fill_(0.)
weights.fill_(0.0)
def gating_init_(weights):
with torch.no_grad():
weights.fill_(0.)
weights.fill_(0.0)
def normal_init_(weights):
torch.nn.init.kaiming_normal_(weights, nonlinearity="linear")
def ipa_point_weights_init_(weights):
with torch.no_grad():
......@@ -100,98 +101,101 @@ def ipa_point_weights_init_(weights):
class Linear(nn.Linear):
"""
A Linear layer with built-in nonstandard initializations. Called just
like torch.nn.Linear.
A Linear layer with built-in nonstandard initializations. Called just
like torch.nn.Linear.
Implements the initializers in 1.11.4, plus some additional ones found
in the code.
Implements the initializers in 1.11.4, plus some additional ones found
in the code.
"""
def __init__(self,
in_dim: int,
out_dim: int,
bias: bool = True,
init: str = "default",
def __init__(
self,
in_dim: int,
out_dim: int,
bias: bool = True,
init: str = "default",
init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
):
"""
Args:
in_dim:
The final dimension of inputs to the layer
out_dim:
The final dimension of layer outputs
bias:
Whether to learn an additive bias. True by default
init:
The initializer to use. Choose from:
"default": LeCun fan-in truncated normal initialization
"relu": He initialization w/ truncated normal distribution
"glorot": Fan-average Glorot uniform initialization
"gating": Weights=0, Bias=1
"normal": Normal initialization with std=1/sqrt(fan_in)
"final": Weights=0, Bias=0
Overridden by init_fn if the latter is not None.
init_fn:
A custom initializer taking weight and bias as inputs.
Overrides init if not None.
"""
Args:
in_dim:
The final dimension of inputs to the layer
out_dim:
The final dimension of layer outputs
bias:
Whether to learn an additive bias. True by default
init:
The initializer to use. Choose from:
"default": LeCun fan-in truncated normal initialization
"relu": He initialization w/ truncated normal distribution
"glorot": Fan-average Glorot uniform initialization
"gating": Weights=0, Bias=1
"normal": Normal initialization with std=1/sqrt(fan_in)
"final": Weights=0, Bias=0
Overridden by init_fn if the latter is not None.
init_fn:
A custom initializer taking weight and bias as inputs.
Overrides init if not None.
"""
super(Linear, self).__init__(in_dim, out_dim, bias=bias)
if(bias):
if bias:
with torch.no_grad():
self.bias.fill_(0)
if(init_fn is not None):
if init_fn is not None:
init_fn(self.weight, self.bias)
else:
if(init == "default"):
if init == "default":
lecun_normal_init_(self.weight)
elif(init == "relu"):
elif init == "relu":
he_normal_init_(self.weight)
elif(init == "glorot"):
elif init == "glorot":
glorot_uniform_init_(self.weight)
elif(init == "gating"):
elif init == "gating":
gating_init_(self.weight)
if(bias):
if bias:
with torch.no_grad():
self.bias.fill_(1.)
elif(init == "normal"):
self.bias.fill_(1.0)
elif init == "normal":
normal_init_(self.weight)
elif(init == "final"):
elif init == "final":
final_init_(self.weight)
else:
raise ValueError("Invalid init string.")
class Attention(nn.Module):
"""
Standard multi-head attention using AlphaFold's default layer
initialization.
"""
def __init__(self,
c_q: int,
c_k: int,
c_v: int,
c_hidden: int,
no_heads: int,
Standard multi-head attention using AlphaFold's default layer
initialization.
"""
def __init__(
self,
c_q: int,
c_k: int,
c_v: int,
c_hidden: int,
no_heads: int,
gating: bool = True,
):
"""
Args:
c_q:
Input dimension of query data
c_k:
Input dimension of key data
c_v:
Input dimension of value data
c_hidden:
Per-head hidden dimension
no_heads:
Number of attention heads
gating:
Whether the output should be gated using query data
Args:
c_q:
Input dimension of query data
c_k:
Input dimension of key data
c_v:
Input dimension of value data
c_hidden:
Per-head hidden dimension
no_heads:
Number of attention heads
gating:
Whether the output should be gated using query data
"""
super(Attention, self).__init__()
......@@ -202,7 +206,7 @@ class Attention(nn.Module):
self.no_heads = no_heads
self.gating = gating
# DISCREPANCY: c_hidden is not the per-head channel dimension, as
# DISCREPANCY: c_hidden is not the per-head channel dimension, as
# stated in the supplement, but the overall channel dimension
self.linear_q = Linear(
......@@ -218,28 +222,31 @@ class Attention(nn.Module):
self.c_hidden * self.no_heads, self.c_q, init="final"
)
if(self.gating is not None):
self.linear_g = Linear(self.c_q, self.c_hidden * self.no_heads, init="gating")
if self.gating is not None:
self.linear_g = Linear(
self.c_q, self.c_hidden * self.no_heads, init="gating"
)
self.sigmoid = nn.Sigmoid()
self.softmax = nn.Softmax(dim=-1)
def forward(self,
q_x: torch.Tensor,
k_x: torch.Tensor,
v_x: torch.Tensor,
def forward(
self,
q_x: torch.Tensor,
k_x: torch.Tensor,
v_x: torch.Tensor,
biases: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor:
"""
Args:
q_x:
[*, Q, C_q] query data
k_x:
[*, K, C_k] key data
v_x:
[*, V, C_v] value data
Returns
[*, Q, C_q] attention update
Args:
q_x:
[*, Q, C_q] query data
k_x:
[*, K, C_k] key data
v_x:
[*, V, C_v] value data
Returns
[*, Q, C_q] attention update
"""
# [*, Q/K/V, H * C_hidden]
q = self.linear_q(q_x)
......@@ -254,11 +261,11 @@ class Attention(nn.Module):
# [*, H, Q, K]
a = torch.matmul(
permute_final_dims(q, (0, 2, 1, 3)), # [*, H, Q, C_hidden]
permute_final_dims(k, (0, 2, 3, 1)), # [*, H, C_hidden, K]
permute_final_dims(k, (0, 2, 3, 1)), # [*, H, C_hidden, K]
)
norm = 1 / math.sqrt(self.c_hidden) # [1]
norm = 1 / math.sqrt(self.c_hidden) # [1]
a = a * norm
if(biases is not None):
if biases is not None:
for b in biases:
a = a + b
a = self.softmax(a)
......@@ -271,18 +278,18 @@ class Attention(nn.Module):
# [*, Q, H, C_hidden]
o = o.transpose(-2, -3)
if(self.gating):
if self.gating:
g = self.sigmoid(self.linear_g(q_x))
# [*, Q, H, C_hidden]
g = g.view(g.shape[:-1] + (self.no_heads, -1))
o = o * g
# [*, Q, H * C_hidden]
o = flatten_final_dims(o, 2)
# [*, Q, C_q]
o = self.linear_o(o)
return o
......@@ -301,10 +308,16 @@ class GlobalAttention(nn.Module):
)
self.linear_k = Linear(
c_in, c_hidden, bias=False, init="glorot",
c_in,
c_hidden,
bias=False,
init="glorot",
)
self.linear_v = Linear(
c_in, c_hidden, bias=False, init="glorot",
c_in,
c_hidden,
bias=False,
init="glorot",
)
self.linear_g = Linear(c_in, c_hidden * no_heads, init="gating")
self.linear_o = Linear(c_hidden * no_heads, c_in, init="final")
......@@ -314,8 +327,9 @@ class GlobalAttention(nn.Module):
def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
# [*, N_res, C_in]
q = (torch.sum(m * mask.unsqueeze(-1), dim=-2) /
(torch.sum(mask, dim=-1)[..., None] + self.eps))
q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / (
torch.sum(mask, dim=-1)[..., None] + self.eps
)
# [*, N_res, H * C_hidden]
q = self.linear_q(q)
......@@ -331,7 +345,7 @@ class GlobalAttention(nn.Module):
# [*, N_res, H, N_seq]
a = torch.matmul(
q,
k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq]
k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq]
)
bias = (self.inf * (mask - 1))[..., :, None, :]
a = a + bias
......@@ -351,7 +365,7 @@ class GlobalAttention(nn.Module):
# [*, N_res, N_seq, H, C_hidden]
o = o.unsqueeze(-3) * g
# [*, N_res, N_seq, H * C_hidden]
o = o.reshape(o.shape[:-2] + (-1,))
......
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
......@@ -25,14 +25,14 @@ from openfold.np.residue_constants import (
restype_atom14_mask,
restype_atom14_rigid_group_positions,
)
from openfold.utils.affine_utils import T, quat_to_rot
from openfold.utils.affine_utils import T, quat_to_rot
from openfold.utils.feats import (
frames_and_literature_positions_to_atom14_pos,
torsion_angles_to_frames,
)
from openfold.utils.tensor_utils import (
dict_multimap,
permute_final_dims,
dict_multimap,
permute_final_dims,
flatten_final_dims,
)
......@@ -40,9 +40,9 @@ from openfold.utils.tensor_utils import (
class AngleResnetBlock(nn.Module):
def __init__(self, c_hidden):
"""
Args:
c_hidden:
Hidden channel dimension
Args:
c_hidden:
Hidden channel dimension
"""
super(AngleResnetBlock, self).__init__()
......@@ -67,21 +67,22 @@ class AngleResnetBlock(nn.Module):
class AngleResnet(nn.Module):
"""
Implements Algorithm 20, lines 11-14
Implements Algorithm 20, lines 11-14
"""
def __init__(self, c_in, c_hidden, no_blocks, no_angles, epsilon):
"""
Args:
c_in:
Input channel dimension
c_hidden:
Hidden channel dimension
no_blocks:
Number of resnet blocks
no_angles:
Number of torsion angles to generate
epsilon:
Small constant for normalization
Args:
c_in:
Input channel dimension
c_hidden:
Hidden channel dimension
no_blocks:
Number of resnet blocks
no_angles:
Number of torsion angles to generate
epsilon:
Small constant for normalization
"""
super(AngleResnet, self).__init__()
......@@ -103,22 +104,21 @@ class AngleResnet(nn.Module):
self.relu = nn.ReLU()
def forward(self,
s: torch.Tensor,
s_initial: torch.Tensor
def forward(
self, s: torch.Tensor, s_initial: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
s:
[*, C_hidden] single embedding
s_initial:
[*, C_hidden] single embedding as of the start of the
StructureModule
Returns:
[*, no_angles, 2] predicted angles
"""
# NOTE: The ReLU's applied to the inputs are absent from the supplement
# pseudocode but present in the source. For maximal compatibility with
Args:
s:
[*, C_hidden] single embedding
s_initial:
[*, C_hidden] single embedding as of the start of the
StructureModule
Returns:
[*, no_angles, 2] predicted angles
"""
# NOTE: The ReLU's applied to the inputs are absent from the supplement
# pseudocode but present in the source. For maximal compatibility with
# the pretrained weights, I'm going with the source.
# [*, C_hidden]
......@@ -153,9 +153,11 @@ class AngleResnet(nn.Module):
class InvariantPointAttention(nn.Module):
"""
Implements Algorithm 22.
Implements Algorithm 22.
"""
def __init__(self,
def __init__(
self,
c_s,
c_z,
c_hidden,
......@@ -166,19 +168,19 @@ class InvariantPointAttention(nn.Module):
eps=1e-8,
):
"""
Args:
c_s:
Single representation channel dimension
c_z:
Pair representation channel dimension
c_hidden:
Hidden channel dimension
no_heads:
Number of attention heads
no_qk_points:
Number of query/key points to generate
no_v_points:
Number of value points to generate
Args:
c_s:
Single representation channel dimension
c_z:
Pair representation channel dimension
c_hidden:
Hidden channel dimension
no_heads:
Number of attention heads
no_qk_points:
Number of query/key points to generate
no_v_points:
Number of value points to generate
"""
super(InvariantPointAttention, self).__init__()
......@@ -212,32 +214,33 @@ class InvariantPointAttention(nn.Module):
self.head_weights = nn.Parameter(torch.zeros((no_heads)))
ipa_point_weights_init_(self.head_weights)
concat_out_dim = self.no_heads * (self.c_z
+ self.c_hidden
+ self.no_v_points * 4)
concat_out_dim = self.no_heads * (
self.c_z + self.c_hidden + self.no_v_points * 4
)
self.linear_out = Linear(concat_out_dim, self.c_s, init="final")
self.softmax = nn.Softmax(dim=-1)
self.softplus = nn.Softplus()
def forward(self,
s: torch.Tensor,
z: torch.Tensor,
def forward(
self,
s: torch.Tensor,
z: torch.Tensor,
t: T,
mask: torch.Tensor,
) -> torch.Tensor:
"""
Args:
s:
[*, N_res, C_s] single representation
z:
[*, N_res, N_res, C_z] pair representation
t:
[*, N_res] affine transformation object
mask:
[*, N_res] mask
Returns:
[*, N_res, C_s] single representation update
Args:
s:
[*, N_res, C_s] single representation
z:
[*, N_res, N_res, C_z] pair representation
t:
[*, N_res] affine transformation object
mask:
[*, N_res] mask
Returns:
[*, N_res, C_s] single representation update
"""
#######################################
# Generate scalar and point activations
......@@ -261,12 +264,12 @@ class InvariantPointAttention(nn.Module):
# This is kind of clunky, but it's how the original does it
# [*, N_res, H * P_q, 3]
q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
q_pts = torch.stack(q_pts, dim=-1)
q_pts = torch.stack(q_pts, dim=-1)
q_pts = t[..., None].apply(q_pts)
# [*, N_res, H, P_q, 3]
q_pts = q_pts.view(
q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3)
q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3)
)
# [*, N_res, H * (P_q + P_v) * 3]
......@@ -278,15 +281,11 @@ class InvariantPointAttention(nn.Module):
kv_pts = t[..., None].apply(kv_pts)
# [*, N_res, H, (P_q + P_v), 3]
kv_pts = kv_pts.view(
kv_pts.shape[:-2] + (self.no_heads, -1, 3)
)
kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))
# [*, N_res, H, P_q/P_v, 3]
k_pts, v_pts = torch.split(
kv_pts,
[self.no_qk_points, self.no_v_points],
dim=-2
kv_pts, [self.no_qk_points, self.no_v_points], dim=-2
)
##########################
......@@ -298,12 +297,12 @@ class InvariantPointAttention(nn.Module):
# [*, H, N_res, N_res]
a = torch.matmul(
permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden]
permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res]
permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden]
permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res]
)
a = a * math.sqrt(1. / (3 * self.c_hidden))
a = a + (math.sqrt(1. / 3) * permute_final_dims(b, (2, 0, 1)))
a = a * math.sqrt(1.0 / (3 * self.c_hidden))
a = a + (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
# [*, N_res, N_res, H, P_q, 3]
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
pt_att = pt_att ** 2
......@@ -312,12 +311,12 @@ class InvariantPointAttention(nn.Module):
pt_att = sum(torch.unbind(pt_att, dim=-1))
head_weights = self.softplus(self.head_weights).view(
*((1,) * len(pt_att.shape[:-2]) + (-1, 1))
)
head_weights = (
head_weights * math.sqrt(1. / (3 * (self.no_qk_points * 9. / 2)))
)
pt_att = pt_att * head_weights
head_weights = head_weights * math.sqrt(
1.0 / (3 * (self.no_qk_points * 9.0 / 2))
)
pt_att = pt_att * head_weights
# [*, N_res, N_res, H]
pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
# [*, N_res, N_res]
......@@ -345,10 +344,10 @@ class InvariantPointAttention(nn.Module):
# [*, H, 3, N_res, P_v]
o_pt = torch.sum(
(
a[..., None, :, :, None] *
permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
a[..., None, :, :, None]
* permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
),
dim=-2
dim=-2,
)
# [*, N_res, H, P_v, 3]
......@@ -357,8 +356,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H * P_v]
o_pt_norm = flatten_final_dims(
torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps),
2
torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps), 2
)
# [*, N_res, H * P_v, 3]
......@@ -372,26 +370,24 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, C_s]
s = self.linear_out(
torch.cat((
o,
*torch.unbind(o_pt, dim=-1),
o_pt_norm,
o_pair
), dim=-1)
torch.cat(
(o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1
)
)
return s
class BackboneUpdate(nn.Module):
"""
Implements Algorithm 23.
Implements Algorithm 23.
"""
def __init__(self, c_s):
"""
Args:
c_s:
Single representation channel dimension
Args:
c_s:
Single representation channel dimension
"""
super(BackboneUpdate, self).__init__()
......@@ -401,24 +397,24 @@ class BackboneUpdate(nn.Module):
def forward(self, s):
"""
Args:
[*, N_res, C_s] single representation
Returns:
[*, N_res] affine transformation object
Args:
[*, N_res, C_s] single representation
Returns:
[*, N_res] affine transformation object
"""
# [*, 6]
params = self.linear(s)
# [*, 3]
quats, trans = params[...,:3], params[...,3:]
quats, trans = params[..., :3], params[..., 3:]
# [*]
#norm_denom = torch.sqrt(sum(torch.unbind(quats ** 2, dim=-1)) + 1)
# norm_denom = torch.sqrt(sum(torch.unbind(quats ** 2, dim=-1)) + 1)
norm_denom = torch.sqrt(torch.sum(quats ** 2, dim=-1) + 1)
# [*, 3]
ones = (
s.new_ones((1,) * len(quats.shape)).expand(quats.shape[:-1] + (1,))
# [*, 3]
ones = s.new_ones((1,) * len(quats.shape)).expand(
quats.shape[:-1] + (1,)
)
# [*, 4]
......@@ -436,7 +432,7 @@ class StructureModuleTransitionLayer(nn.Module):
super(StructureModuleTransitionLayer, self).__init__()
self.c = c
self.linear_1 = Linear(self.c, self.c, init="relu")
self.linear_2 = Linear(self.c, self.c, init="relu")
self.linear_3 = Linear(self.c, self.c, init="final")
......@@ -483,8 +479,9 @@ class StructureModuleTransition(nn.Module):
class StructureModule(nn.Module):
def __init__(self,
c_s,
def __init__(
self,
c_s,
c_z,
c_ipa,
c_resnet,
......@@ -502,39 +499,39 @@ class StructureModule(nn.Module):
**kwargs,
):
"""
Args:
c_s:
Single representation channel dimension
c_z:
Pair representation channel dimension
c_ipa:
IPA hidden channel dimension
c_resnet:
Angle resnet (Alg. 23 lines 11-14) hidden channel dimension
no_heads_ipa:
Number of IPA heads
no_qk_points:
Number of query/key points to generate during IPA
no_v_points:
Number of value points to generate during IPA
dropout_rate:
Dropout rate used throughout the layer
no_blocks:
Number of structure module blocks
no_transition_layers:
Number of layers in the single representation transition
(Alg. 23 lines 8-9)
no_resnet_blocks:
Number of blocks in the angle resnet
no_angles:
Number of angles to generate in the angle resnet
trans_scale_factor:
Scale of single representation transition hidden dimension
epsilon:
Small number used in angle resnet normalization
inf:
Large number used for attention masking
"""
Args:
c_s:
Single representation channel dimension
c_z:
Pair representation channel dimension
c_ipa:
IPA hidden channel dimension
c_resnet:
Angle resnet (Alg. 23 lines 11-14) hidden channel dimension
no_heads_ipa:
Number of IPA heads
no_qk_points:
Number of query/key points to generate during IPA
no_v_points:
Number of value points to generate during IPA
dropout_rate:
Dropout rate used throughout the layer
no_blocks:
Number of structure module blocks
no_transition_layers:
Number of layers in the single representation transition
(Alg. 23 lines 8-9)
no_resnet_blocks:
Number of blocks in the angle resnet
no_angles:
Number of angles to generate in the angle resnet
trans_scale_factor:
Scale of single representation transition hidden dimension
epsilon:
Small number used in angle resnet normalization
inf:
Large number used for attention masking
"""
super(StructureModule, self).__init__()
self.c_s = c_s
......@@ -587,33 +584,34 @@ class StructureModule(nn.Module):
self.bb_update = BackboneUpdate(self.c_s)
self.angle_resnet = AngleResnet(
self.c_s,
self.c_resnet,
self.c_s,
self.c_resnet,
self.no_resnet_blocks,
self.no_angles,
self.epsilon,
)
def forward(self,
def forward(
self,
s,
z,
f,
mask=None,
):
"""
Args:
s:
[*, N_res, C_s] single representation
z:
[*, N_res, N_res, C_z] pair representation
f:
[*, N_res] amino acid indices
mask:
Optional [*, N_res] sequence mask
Returns:
A dictionary of outputs
Args:
s:
[*, N_res, C_s] single representation
z:
[*, N_res, N_res, C_z] pair representation
f:
[*, N_res] amino acid indices
mask:
Optional [*, N_res] sequence mask
Returns:
A dictionary of outputs
"""
if(mask is None):
if mask is None:
# [*, N]
mask = s.new_ones(s.shape[:-1])
......@@ -644,7 +642,9 @@ class StructureModule(nn.Module):
unnormalized_a, a = self.angle_resnet(s, s_initial)
all_frames_to_global = self.torsion_angles_to_frames(
t.scale_translation(self.trans_scale_factor), a, f,
t.scale_translation(self.trans_scale_factor),
a,
f,
)
pred_xyz = self.frames_and_literature_positions_to_atom14_pos(
......@@ -653,8 +653,7 @@ class StructureModule(nn.Module):
)
preds = {
"frames":
t.scale_translation(self.trans_scale_factor).to_4x4(),
"frames": t.scale_translation(self.trans_scale_factor).to_4x4(),
"sidechain_frames": all_frames_to_global.to_4x4(),
"unnormalized_angles": unnormalized_a,
"angles": a,
......@@ -663,7 +662,7 @@ class StructureModule(nn.Module):
outputs.append(preds)
if(i < (self.no_blocks - 1)):
if i < (self.no_blocks - 1):
t = t.stop_rot_gradient()
outputs = dict_multimap(torch.stack, outputs)
......@@ -672,28 +671,28 @@ class StructureModule(nn.Module):
return outputs
def _init_residue_constants(self, float_dtype, device):
if(self.default_frames is None):
if self.default_frames is None:
self.default_frames = torch.tensor(
restype_rigid_group_default_frame,
dtype=float_dtype,
device=device,
)
if(self.group_idx is None):
if self.group_idx is None:
self.group_idx = torch.tensor(
restype_atom14_to_rigid_group,
device=device,
)
if(self.atom_mask is None):
if self.atom_mask is None:
self.atom_mask = torch.tensor(
restype_atom14_mask,
dtype=float_dtype,
device=device,
)
if(self.lit_positions is None):
if self.lit_positions is None:
self.lit_positions = torch.tensor(
restype_atom14_rigid_group_positions,
dtype=float_dtype,
device=device,
device=device,
)
def torsion_angles_to_frames(self, t, alpha, f):
......@@ -702,17 +701,16 @@ class StructureModule(nn.Module):
# Separated purely to make testing less annoying
return torsion_angles_to_frames(t, alpha, f, self.default_frames)
def frames_and_literature_positions_to_atom14_pos(self,
t, # [*, N, 8]
f # [*, N]
def frames_and_literature_positions_to_atom14_pos(
self, t, f # [*, N, 8] # [*, N]
):
# Lazily initialize the residue constants on the correct device
self._init_residue_constants(t.rots.dtype, t.rots.device)
return frames_and_literature_positions_to_atom14_pos(
t,
f,
self.default_frames,
t,
f,
self.default_frames,
self.group_idx,
self.atom_mask,
self.lit_positions,
self.lit_positions,
)
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
......@@ -18,9 +18,9 @@ import math
import torch
import torch.nn as nn
from openfold.model.primitives import Linear, Attention
from openfold.model.primitives import Linear, Attention
from openfold.utils.deepspeed import checkpoint_blocks
from openfold.model.dropout import (
from openfold.model.dropout import (
DropoutRowwise,
DropoutColumnwise,
)
......@@ -35,35 +35,28 @@ from openfold.model.triangular_multiplicative_update import (
)
from openfold.utils.tensor_utils import (
chunk_layer,
permute_final_dims,
permute_final_dims,
flatten_final_dims,
)
class TemplatePointwiseAttention(nn.Module):
"""
Implements Algorithm 17.
Implements Algorithm 17.
"""
def __init__(self,
c_t,
c_z,
c_hidden,
no_heads,
chunk_size,
inf,
**kwargs
):
def __init__(self, c_t, c_z, c_hidden, no_heads, chunk_size, inf, **kwargs):
"""
Args:
c_t:
Template embedding channel dimension
c_z:
Pair embedding channel dimension
c_hidden:
Hidden channel dimension
Args:
c_t:
Template embedding channel dimension
c_z:
Pair embedding channel dimension
c_hidden:
Hidden channel dimension
"""
super(TemplatePointwiseAttention, self).__init__()
self.c_t = c_t
self.c_z = c_z
self.c_hidden = c_hidden
......@@ -72,30 +65,33 @@ class TemplatePointwiseAttention(nn.Module):
self.inf = inf
self.mha = Attention(
self.c_z, self.c_t, self.c_t,
self.c_hidden, self.no_heads,
self.c_z,
self.c_t,
self.c_t,
self.c_hidden,
self.no_heads,
gating=False,
)
def forward(self, t, z, template_mask=None):
"""
Args:
t:
[*, N_templ, N_res, N_res, C_t] template embedding
z:
[*, N_res, N_res, C_t] pair embedding
template_mask:
[*, N_templ] template mask
Returns:
[*, N_res, N_res, C_z] pair embedding update
Args:
t:
[*, N_templ, N_res, N_res, C_t] template embedding
z:
[*, N_res, N_res, C_t] pair embedding
template_mask:
[*, N_templ] template mask
Returns:
[*, N_res, N_res, C_z] pair embedding update
"""
if(template_mask is None):
if template_mask is None:
# NOTE: This is not the "template_mask" from the supplement, but a
# [*, N_templ] mask from the code. I'm pretty sure it's always just
# [*, N_templ] mask from the code. I'm pretty sure it's always just
# 1, but not sure enough to remove it. It's nice to have, I guess.
template_mask = t.new_ones(t.shape[:-3])
bias = (self.inf * (template_mask[..., None, None, None, None, :] - 1))
bias = self.inf * (template_mask[..., None, None, None, None, :] - 1)
# [*, N_res, N_res, 1, C_z]
z = z.unsqueeze(-2)
......@@ -110,36 +106,37 @@ class TemplatePointwiseAttention(nn.Module):
"v_x": t,
"biases": [bias],
}
if(self.chunk_size is not None):
if self.chunk_size is not None:
z = chunk_layer(
self.mha,
mha_inputs,
chunk_size=self.chunk_size,
no_batch_dims=len(z.shape[:-2])
no_batch_dims=len(z.shape[:-2]),
)
else:
z = self.mha(**mha_inputs)
# [*, N_res, N_res, C_z]
z = z.squeeze(-2)
return z
class TemplatePairStackBlock(nn.Module):
def __init__(self,
c_t,
def __init__(
self,
c_t,
c_hidden_tri_att,
c_hidden_tri_mul,
no_heads,
pair_transition_n,
no_heads,
pair_transition_n,
dropout_rate,
chunk_size,
inf,
**kwargs,
):
super(TemplatePairStackBlock, self).__init__()
self.c_t = c_t
self.c_hidden_tri_att = c_hidden_tri_att
self.c_hidden_tri_mul = c_hidden_tri_mul
......@@ -151,11 +148,11 @@ class TemplatePairStackBlock(nn.Module):
self.dropout_row = DropoutRowwise(self.dropout_rate)
self.dropout_col = DropoutColumnwise(self.dropout_rate)
self.tri_att_start = TriangleAttentionStartingNode(
self.c_t,
self.c_hidden_tri_att,
self.no_heads,
self.c_t,
self.c_hidden_tri_att,
self.no_heads,
chunk_size=chunk_size,
inf=inf,
)
......@@ -188,21 +185,23 @@ class TemplatePairStackBlock(nn.Module):
z = z + self.dropout_row(self.tri_mul_out(z, mask=mask))
z = z + self.dropout_row(self.tri_mul_in(z, mask=mask))
z = z + self.pair_transition(z, mask=mask if _mask_trans else None)
return z
class TemplatePairStack(nn.Module):
"""
Implements Algorithm 16.
Implements Algorithm 16.
"""
def __init__(self,
c_t,
def __init__(
self,
c_t,
c_hidden_tri_att,
c_hidden_tri_mul,
no_blocks,
no_heads,
pair_transition_n,
no_blocks,
no_heads,
pair_transition_n,
dropout_rate,
blocks_per_ckpt,
chunk_size,
......@@ -210,26 +209,26 @@ class TemplatePairStack(nn.Module):
**kwargs,
):
"""
Args:
c_t:
Template embedding channel dimension
c_hidden_tri_att:
Per-head hidden dimension for triangular attention
c_hidden_tri_att:
Hidden dimension for triangular multiplication
no_blocks:
Number of blocks in the stack
pair_transition_n:
Scale of pair transition (Alg. 15) hidden dimension
dropout_rate:
Dropout rate used throughout the stack
blocks_per_ckpt:
Number of blocks per activation checkpoint. None disables
activation checkpointing
chunk_size:
Size of subbatches. A higher value increases throughput at
the cost of memory
"""
Args:
c_t:
Template embedding channel dimension
c_hidden_tri_att:
Per-head hidden dimension for triangular attention
c_hidden_tri_att:
Hidden dimension for triangular multiplication
no_blocks:
Number of blocks in the stack
pair_transition_n:
Scale of pair transition (Alg. 15) hidden dimension
dropout_rate:
Dropout rate used throughout the stack
blocks_per_ckpt:
Number of blocks per activation checkpoint. None disables
activation checkpointing
chunk_size:
Size of subbatches. A higher value increases throughput at
the cost of memory
"""
super(TemplatePairStack, self).__init__()
self.blocks_per_ckpt = blocks_per_ckpt
......@@ -250,28 +249,30 @@ class TemplatePairStack(nn.Module):
self.layer_norm = nn.LayerNorm(c_t)
def forward(self,
def forward(
self,
t: torch.tensor,
mask: torch.tensor,
_mask_trans: bool = True,
):
"""
Args:
t:
[*, N_res, N_res, C_t] template embedding
mask:
[*, N_res, N_res] mask
Returns:
[*, N_res, N_res, C_t] template embedding update
Args:
t:
[*, N_res, N_res, C_t] template embedding
mask:
[*, N_res, N_res] mask
Returns:
[*, N_res, N_res, C_t] template embedding update
"""
t, = checkpoint_blocks(
(t,) = checkpoint_blocks(
blocks=[
partial(
b,
mask=mask,
b,
mask=mask,
_mask_trans=_mask_trans,
) for b in self.blocks
],
)
for b in self.blocks
],
args=(t,),
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
)
......
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
......@@ -20,29 +20,24 @@ import torch.nn as nn
from openfold.model.primitives import Linear, Attention
from openfold.utils.tensor_utils import (
chunk_layer,
permute_final_dims,
chunk_layer,
permute_final_dims,
flatten_final_dims,
)
class TriangleAttention(nn.Module):
def __init__(self,
c_in,
c_hidden,
no_heads,
starting,
chunk_size=4,
inf=1e9
def __init__(
self, c_in, c_hidden, no_heads, starting, chunk_size=4, inf=1e9
):
"""
Args:
c_in:
Input channel dimension
c_hidden:
Overall hidden channel dimension (not per-head)
no_heads:
Number of attention heads
Args:
c_in:
Input channel dimension
c_hidden:
Overall hidden channel dimension (not per-head)
no_heads:
Number of attention heads
"""
super(TriangleAttention, self).__init__()
......@@ -54,40 +49,38 @@ class TriangleAttention(nn.Module):
self.inf = inf
self.layer_norm = nn.LayerNorm(self.c_in)
self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")
self.mha = Attention(
self.c_in, self.c_in, self.c_in,
self.c_hidden,
self.no_heads
self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
)
def forward(self, x, mask=None):
"""
Args:
x:
[*, I, J, C_in] input tensor (e.g. the pair representation)
Returns:
[*, I, J, C_in] output tensor
Args:
x:
[*, I, J, C_in] input tensor (e.g. the pair representation)
Returns:
[*, I, J, C_in] output tensor
"""
if(mask is None):
if mask is None:
# [*, I, J]
mask = x.new_ones(
x.shape[:-1],
x.shape[:-1],
)
# Shape annotations assume self.starting. Else, I and J are flipped
if(not self.starting):
if not self.starting:
x = x.transpose(-2, -3)
mask = mask.transpose(-1, -2)
# [*, I, J, C_in]
x = self.layer_norm(x)
# [*, I, 1, 1, J]
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
# [*, H, I, J]
triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))
......@@ -100,17 +93,17 @@ class TriangleAttention(nn.Module):
"v_x": x,
"biases": [mask_bias, triangle_bias],
}
if(self.chunk_size is not None):
if self.chunk_size is not None:
x = chunk_layer(
self.mha,
mha_inputs,
mha_inputs,
chunk_size=self.chunk_size,
no_batch_dims=len(x.shape[:-2])
no_batch_dims=len(x.shape[:-2]),
)
else:
x = self.mha(**mha_inputs)
if(not self.starting):
if not self.starting:
x = x.transpose(-2, -3)
return x
......@@ -118,13 +111,15 @@ class TriangleAttention(nn.Module):
class TriangleAttentionStartingNode(TriangleAttention):
"""
Implements Algorithm 13.
Implements Algorithm 13.
"""
__init__ = partialmethod(TriangleAttention.__init__, starting=True)
class TriangleAttentionEndingNode(TriangleAttention):
"""
Implements Algorithm 14.
Implements Algorithm 14.
"""
__init__ = partialmethod(TriangleAttention.__init__, starting=False)
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
......@@ -23,16 +23,17 @@ from openfold.utils.tensor_utils import permute_final_dims
class TriangleMultiplicativeUpdate(nn.Module):
"""
Implements Algorithms 11 and 12.
Implements Algorithms 11 and 12.
"""
def __init__(self, c_z, c_hidden, _outgoing=True):
"""
Args:
c_z:
Input channel dimension
c:
Hidden channel dimension
"""
Args:
c_z:
Input channel dimension
c:
Hidden channel dimension
"""
super(TriangleMultiplicativeUpdate, self).__init__()
self.c_z = c_z
self.c_hidden = c_hidden
......@@ -53,22 +54,24 @@ class TriangleMultiplicativeUpdate(nn.Module):
cp = self._outgoing_matmul if self._outgoing else self._incoming_matmul
self.combine_projections = cp
def _outgoing_matmul(self,
a: torch.Tensor, # [*, N_i, N_k, C]
b: torch.Tensor, # [*, N_j, N_k, C]
def _outgoing_matmul(
self,
a: torch.Tensor, # [*, N_i, N_k, C]
b: torch.Tensor, # [*, N_j, N_k, C]
):
# [*, C, N_i, N_j]
p = torch.matmul(
permute_final_dims(a, (2, 0, 1)),
permute_final_dims(b, (2, 1, 0)),
)
# [*, N_i, N_j, C]
return permute_final_dims(p, (1, 2, 0))
def _incoming_matmul(self,
a: torch.Tensor, # [*, N_k, N_i, C]
b: torch.Tensor, # [*, N_k, N_j, C]
def _incoming_matmul(
self,
a: torch.Tensor, # [*, N_k, N_i, C]
b: torch.Tensor, # [*, N_k, N_j, C]
):
# [*, C, N_i, N_j]
......@@ -76,21 +79,21 @@ class TriangleMultiplicativeUpdate(nn.Module):
permute_final_dims(a, (2, 1, 0)),
permute_final_dims(b, (2, 0, 1)),
)
# [*, N_i, N_j, C]
return permute_final_dims(p, (1, 2, 0))
def forward(self, z, mask=None):
"""
Args:
x:
[*, N_res, N_res, C_z] input tensor
mask:
[*, N_res, N_res] input mask
Returns:
[*, N_res, N_res, C_z] output tensor
Args:
x:
[*, N_res, N_res, C_z] input tensor
mask:
[*, N_res, N_res] input mask
Returns:
[*, N_res, N_res, C_z] output tensor
"""
if(mask is None):
if mask is None:
mask = z.new_ones(z.shape[:-1])
mask = mask.unsqueeze(-1)
......@@ -111,17 +114,21 @@ class TriangleMultiplicativeUpdate(nn.Module):
class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
"""
Implements Algorithm 11.
Implements Algorithm 11.
"""
__init__ = partialmethod(
TriangleMultiplicativeUpdate.__init__, _outgoing=True,
TriangleMultiplicativeUpdate.__init__,
_outgoing=True,
)
class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
"""
Implements Algorithm 12.
Implements Algorithm 12.
"""
__init__ = partialmethod(
TriangleMultiplicativeUpdate.__init__, _outgoing=False,
TriangleMultiplicativeUpdate.__init__,
_outgoing=False,
)
......@@ -3,12 +3,14 @@ import glob
import importlib as importlib
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [os.path.basename(f)[:-3] for f in _files if os.path.isfile(f) and not f.endswith("__init__.py")]
_modules = [(m, importlib.import_module('.' + m, __name__)) for m in __all__]
__all__ = [
os.path.basename(f)[:-3]
for f in _files
if os.path.isfile(f) and not f.endswith("__init__.py")
]
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
globals()[_m[0]] = _m[1]
# Avoid needlessly cluttering the global namespace
del _files, _m, _modules
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
......@@ -27,204 +27,220 @@ ModelOutput = Mapping[str, Any] # Is a nested dict.
@dataclasses.dataclass(frozen=True)
class Protein:
"""Protein structure representation."""
"""Protein structure representation."""
# Cartesian coordinates of atoms in angstroms. The atom types correspond to
# residue_constants.atom_types, i.e. the first three are N, CA, CB.
atom_positions: np.ndarray # [num_res, num_atom_type, 3]
# Cartesian coordinates of atoms in angstroms. The atom types correspond to
# residue_constants.atom_types, i.e. the first three are N, CA, CB.
atom_positions: np.ndarray # [num_res, num_atom_type, 3]
# Amino-acid type for each residue represented as an integer between 0 and
# 20, where 20 is 'X'.
aatype: np.ndarray # [num_res]
# Amino-acid type for each residue represented as an integer between 0 and
# 20, where 20 is 'X'.
aatype: np.ndarray # [num_res]
# Binary float mask to indicate presence of a particular atom. 1.0 if an atom
# is present and 0.0 if not. This should be used for loss masking.
atom_mask: np.ndarray # [num_res, num_atom_type]
# Binary float mask to indicate presence of a particular atom. 1.0 if an atom
# is present and 0.0 if not. This should be used for loss masking.
atom_mask: np.ndarray # [num_res, num_atom_type]
# Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
residue_index: np.ndarray # [num_res]
# Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
residue_index: np.ndarray # [num_res]
# B-factors, or temperature factors, of each residue (in sq. angstroms units),
# representing the displacement of the residue from its ground truth mean
# value.
b_factors: np.ndarray # [num_res, num_atom_type]
# B-factors, or temperature factors, of each residue (in sq. angstroms units),
# representing the displacement of the residue from its ground truth mean
# value.
b_factors: np.ndarray # [num_res, num_atom_type]
def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
"""Takes a PDB string and constructs a Protein object.
WARNING: All non-standard residue types will be converted into UNK. All
non-standard atoms will be ignored.
Args:
pdb_str: The contents of the pdb file
chain_id: If None, then the pdb file must contain a single chain (which
will be parsed). If chain_id is specified (e.g. A), then only that chain
is parsed.
Returns:
A new `Protein` parsed from the pdb contents.
"""
pdb_fh = io.StringIO(pdb_str)
parser = PDBParser(QUIET=True)
structure = parser.get_structure('none', pdb_fh)
models = list(structure.get_models())
if len(models) != 1:
raise ValueError(
f'Only single model PDBs are supported. Found {len(models)} models.')
model = models[0]
if chain_id is not None:
chain = model[chain_id]
else:
chains = list(model.get_chains())
if len(chains) != 1:
raise ValueError(
'Only single chain PDBs are supported when chain_id not specified. '
f'Found {len(chains)} chains.')
"""Takes a PDB string and constructs a Protein object.
WARNING: All non-standard residue types will be converted into UNK. All
non-standard atoms will be ignored.
Args:
pdb_str: The contents of the pdb file
chain_id: If None, then the pdb file must contain a single chain (which
will be parsed). If chain_id is specified (e.g. A), then only that chain
is parsed.
Returns:
A new `Protein` parsed from the pdb contents.
"""
pdb_fh = io.StringIO(pdb_str)
parser = PDBParser(QUIET=True)
structure = parser.get_structure("none", pdb_fh)
models = list(structure.get_models())
if len(models) != 1:
raise ValueError(
f"Only single model PDBs are supported. Found {len(models)} models."
)
model = models[0]
if chain_id is not None:
chain = model[chain_id]
else:
chain = chains[0]
atom_positions = []
aatype = []
atom_mask = []
residue_index = []
b_factors = []
for res in chain:
if res.id[2] != ' ':
raise ValueError(
f'PDB contains an insertion code at chain {chain.id} and residue '
f'index {res.id[1]}. These are not supported.')
res_shortname = residue_constants.restype_3to1.get(res.resname, 'X')
restype_idx = residue_constants.restype_order.get(
res_shortname, residue_constants.restype_num)
pos = np.zeros((residue_constants.atom_type_num, 3))
mask = np.zeros((residue_constants.atom_type_num,))
res_b_factors = np.zeros((residue_constants.atom_type_num,))
for atom in res:
if atom.name not in residue_constants.atom_types:
continue
pos[residue_constants.atom_order[atom.name]] = atom.coord
mask[residue_constants.atom_order[atom.name]] = 1.
res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor
if np.sum(mask) < 0.5:
# If no known atom positions are reported for the residue then skip it.
continue
aatype.append(restype_idx)
atom_positions.append(pos)
atom_mask.append(mask)
residue_index.append(res.id[1])
b_factors.append(res_b_factors)
return Protein(
atom_positions=np.array(atom_positions),
atom_mask=np.array(atom_mask),
aatype=np.array(aatype),
residue_index=np.array(residue_index),
b_factors=np.array(b_factors))
chains = list(model.get_chains())
if len(chains) != 1:
raise ValueError(
"Only single chain PDBs are supported when chain_id not specified. "
f"Found {len(chains)} chains."
)
else:
chain = chains[0]
atom_positions = []
aatype = []
atom_mask = []
residue_index = []
b_factors = []
for res in chain:
if res.id[2] != " ":
raise ValueError(
f"PDB contains an insertion code at chain {chain.id} and residue "
f"index {res.id[1]}. These are not supported."
)
res_shortname = residue_constants.restype_3to1.get(res.resname, "X")
restype_idx = residue_constants.restype_order.get(
res_shortname, residue_constants.restype_num
)
pos = np.zeros((residue_constants.atom_type_num, 3))
mask = np.zeros((residue_constants.atom_type_num,))
res_b_factors = np.zeros((residue_constants.atom_type_num,))
for atom in res:
if atom.name not in residue_constants.atom_types:
continue
pos[residue_constants.atom_order[atom.name]] = atom.coord
mask[residue_constants.atom_order[atom.name]] = 1.0
res_b_factors[
residue_constants.atom_order[atom.name]
] = atom.bfactor
if np.sum(mask) < 0.5:
# If no known atom positions are reported for the residue then skip it.
continue
aatype.append(restype_idx)
atom_positions.append(pos)
atom_mask.append(mask)
residue_index.append(res.id[1])
b_factors.append(res_b_factors)
return Protein(
atom_positions=np.array(atom_positions),
atom_mask=np.array(atom_mask),
aatype=np.array(aatype),
residue_index=np.array(residue_index),
b_factors=np.array(b_factors),
)
def to_pdb(prot: Protein) -> str:
"""Converts a `Protein` instance to a PDB string.
Args:
prot: The protein to convert to PDB.
Returns:
PDB string.
"""
restypes = residue_constants.restypes + ['X']
res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], 'UNK')
atom_types = residue_constants.atom_types
pdb_lines = []
atom_mask = prot.atom_mask
aatype = prot.aatype
atom_positions = prot.atom_positions
residue_index = prot.residue_index.astype(np.int32)
b_factors = prot.b_factors
if np.any(aatype > residue_constants.restype_num):
raise ValueError('Invalid aatypes.')
pdb_lines.append('MODEL 1')
atom_index = 1
chain_id = 'A'
# Add all atom sites.
for i in range(aatype.shape[0]):
res_name_3 = res_1to3(aatype[i])
for atom_name, pos, mask, b_factor in zip(
atom_types, atom_positions[i], atom_mask[i], b_factors[i]):
if mask < 0.5:
continue
record_type = 'ATOM'
name = atom_name if len(atom_name) == 4 else f' {atom_name}'
alt_loc = ''
insertion_code = ''
occupancy = 1.00
element = atom_name[0] # Protein supports only C, N, O, S, this works.
charge = ''
# PDB is a columnar format, every space matters here!
atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}'
f'{res_name_3:>3} {chain_id:>1}'
f'{residue_index[i]:>4}{insertion_code:>1} '
f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}'
f'{occupancy:>6.2f}{b_factor:>6.2f} '
f'{element:>2}{charge:>2}')
pdb_lines.append(atom_line)
atom_index += 1
# Close the chain.
chain_end = 'TER'
chain_termination_line = (
f'{chain_end:<6}{atom_index:>5} {res_1to3(aatype[-1]):>3} '
f'{chain_id:>1}{residue_index[-1]:>4}')
pdb_lines.append(chain_termination_line)
pdb_lines.append('ENDMDL')
pdb_lines.append('END')
pdb_lines.append('')
return '\n'.join(pdb_lines)
"""Converts a `Protein` instance to a PDB string.
Args:
prot: The protein to convert to PDB.
Returns:
PDB string.
"""
restypes = residue_constants.restypes + ["X"]
res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], "UNK")
atom_types = residue_constants.atom_types
pdb_lines = []
atom_mask = prot.atom_mask
aatype = prot.aatype
atom_positions = prot.atom_positions
residue_index = prot.residue_index.astype(np.int32)
b_factors = prot.b_factors
if np.any(aatype > residue_constants.restype_num):
raise ValueError("Invalid aatypes.")
pdb_lines.append("MODEL 1")
atom_index = 1
chain_id = "A"
# Add all atom sites.
for i in range(aatype.shape[0]):
res_name_3 = res_1to3(aatype[i])
for atom_name, pos, mask, b_factor in zip(
atom_types, atom_positions[i], atom_mask[i], b_factors[i]
):
if mask < 0.5:
continue
record_type = "ATOM"
name = atom_name if len(atom_name) == 4 else f" {atom_name}"
alt_loc = ""
insertion_code = ""
occupancy = 1.00
element = atom_name[
0
] # Protein supports only C, N, O, S, this works.
charge = ""
# PDB is a columnar format, every space matters here!
atom_line = (
f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
f"{res_name_3:>3} {chain_id:>1}"
f"{residue_index[i]:>4}{insertion_code:>1} "
f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
f"{occupancy:>6.2f}{b_factor:>6.2f} "
f"{element:>2}{charge:>2}"
)
pdb_lines.append(atom_line)
atom_index += 1
# Close the chain.
chain_end = "TER"
chain_termination_line = (
f"{chain_end:<6}{atom_index:>5} {res_1to3(aatype[-1]):>3} "
f"{chain_id:>1}{residue_index[-1]:>4}"
)
pdb_lines.append(chain_termination_line)
pdb_lines.append("ENDMDL")
pdb_lines.append("END")
pdb_lines.append("")
return "\n".join(pdb_lines)
def ideal_atom_mask(prot: Protein) -> np.ndarray:
"""Computes an ideal atom mask.
`Protein.atom_mask` typically is defined according to the atoms that are
reported in the PDB. This function computes a mask according to heavy atoms
that should be present in the given sequence of amino acids.
Args:
prot: `Protein` whose fields are `numpy.ndarray` objects.
Returns:
An ideal atom mask.
"""
return residue_constants.STANDARD_ATOM_MASK[prot.aatype]
def from_prediction(features: FeatureDict, result: ModelOutput,
b_factors: Optional[np.ndarray] = None) -> Protein:
"""Assembles a protein from a prediction.
Args:
features: Dictionary holding model inputs.
result: Dictionary holding model outputs.
b_factors: (Optional) B-factors to use for the protein.
Returns:
A protein instance.
"""
if b_factors is None:
b_factors = np.zeros_like(result['final_atom_mask'])
return Protein(
aatype=features['aatype'],
atom_positions=result['final_atom_positions'],
atom_mask=result['final_atom_mask'],
residue_index=features['residue_index'] + 1,
b_factors=b_factors
)
"""Computes an ideal atom mask.
`Protein.atom_mask` typically is defined according to the atoms that are
reported in the PDB. This function computes a mask according to heavy atoms
that should be present in the given sequence of amino acids.
Args:
prot: `Protein` whose fields are `numpy.ndarray` objects.
Returns:
An ideal atom mask.
"""
return residue_constants.STANDARD_ATOM_MASK[prot.aatype]
def from_prediction(
features: FeatureDict,
result: ModelOutput,
b_factors: Optional[np.ndarray] = None,
) -> Protein:
"""Assembles a protein from a prediction.
Args:
features: Dictionary holding model inputs.
result: Dictionary holding model outputs.
b_factors: (Optional) B-factors to use for the protein.
Returns:
A protein instance.
"""
if b_factors is None:
b_factors = np.zeros_like(result["final_atom_mask"])
return Protein(
aatype=features["aatype"],
atom_positions=result["final_atom_positions"],
atom_mask=result["final_atom_mask"],
residue_index=features["residue_index"] + 1,
b_factors=b_factors,
)
......@@ -3,13 +3,14 @@ import glob
import importlib as importlib
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [os.path.basename(f)[:-3] for f in _files if os.path.isfile(f) and not f.endswith("__init__.py")]
_modules = [(m, importlib.import_module('.' + m, __name__)) for m in __all__]
__all__ = [
os.path.basename(f)[:-3]
for f in _files
if os.path.isfile(f) and not f.endswith("__init__.py")
]
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
globals()[_m[0]] = _m[1]
# Avoid needlessly cluttering the global namespace
del _files, _m, _modules
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
......@@ -38,12 +38,12 @@ LENGTH = unit.angstroms
def will_restrain(atom: openmm_app.Atom, rset: str) -> bool:
"""Returns True if the atom will be restrained by the given restraint set."""
"""Returns True if the atom will be restrained by the given restraint set."""
if rset == "non_hydrogen":
return atom.element.name != "hydrogen"
elif rset == "c_alpha":
return atom.name == "CA"
if rset == "non_hydrogen":
return atom.element.name != "hydrogen"
elif rset == "c_alpha":
return atom.name == "CA"
def _add_restraints(
......@@ -51,24 +51,29 @@ def _add_restraints(
reference_pdb: openmm_app.PDBFile,
stiffness: unit.Unit,
rset: str,
exclude_residues: Sequence[int]):
"""Adds a harmonic potential that restrains the system to a structure."""
assert rset in ["non_hydrogen", "c_alpha"]
force = openmm.CustomExternalForce(
"0.5 * k * ((x-x0)^2 + (y-y0)^2 + (z-z0)^2)")
force.addGlobalParameter("k", stiffness)
for p in ["x0", "y0", "z0"]:
force.addPerParticleParameter(p)
for i, atom in enumerate(reference_pdb.topology.atoms()):
if atom.residue.index in exclude_residues:
continue
if will_restrain(atom, rset):
force.addParticle(i, reference_pdb.positions[i])
logging.info("Restraining %d / %d particles.",
force.getNumParticles(), system.getNumParticles())
system.addForce(force)
exclude_residues: Sequence[int],
):
"""Adds a harmonic potential that restrains the system to a structure."""
assert rset in ["non_hydrogen", "c_alpha"]
force = openmm.CustomExternalForce(
"0.5 * k * ((x-x0)^2 + (y-y0)^2 + (z-z0)^2)"
)
force.addGlobalParameter("k", stiffness)
for p in ["x0", "y0", "z0"]:
force.addPerParticleParameter(p)
for i, atom in enumerate(reference_pdb.topology.atoms()):
if atom.residue.index in exclude_residues:
continue
if will_restrain(atom, rset):
force.addParticle(i, reference_pdb.positions[i])
logging.info(
"Restraining %d / %d particles.",
force.getNumParticles(),
system.getNumParticles(),
)
system.addForce(force)
def _openmm_minimize(
......@@ -77,291 +82,324 @@ def _openmm_minimize(
tolerance: unit.Unit,
stiffness: unit.Unit,
restraint_set: str,
exclude_residues: Sequence[int]):
"""Minimize energy via openmm."""
pdb_file = io.StringIO(pdb_str)
pdb = openmm_app.PDBFile(pdb_file)
force_field = openmm_app.ForceField("amber99sb.xml")
constraints = openmm_app.HBonds
system = force_field.createSystem(
pdb.topology, constraints=constraints)
if stiffness > 0 * ENERGY / (LENGTH**2):
_add_restraints(system, pdb, stiffness, restraint_set, exclude_residues)
integrator = openmm.LangevinIntegrator(0, 0.01, 0.0)
platform = openmm.Platform.getPlatformByName("CPU")
simulation = openmm_app.Simulation(
pdb.topology, system, integrator, platform)
simulation.context.setPositions(pdb.positions)
ret = {}
state = simulation.context.getState(getEnergy=True, getPositions=True)
ret["einit"] = state.getPotentialEnergy().value_in_unit(ENERGY)
ret["posinit"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH)
simulation.minimizeEnergy(maxIterations=max_iterations,
tolerance=tolerance)
state = simulation.context.getState(getEnergy=True, getPositions=True)
ret["efinal"] = state.getPotentialEnergy().value_in_unit(ENERGY)
ret["pos"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH)
ret["min_pdb"] = _get_pdb_string(simulation.topology, state.getPositions())
return ret
exclude_residues: Sequence[int],
):
"""Minimize energy via openmm."""
pdb_file = io.StringIO(pdb_str)
pdb = openmm_app.PDBFile(pdb_file)
force_field = openmm_app.ForceField("amber99sb.xml")
constraints = openmm_app.HBonds
system = force_field.createSystem(pdb.topology, constraints=constraints)
if stiffness > 0 * ENERGY / (LENGTH ** 2):
_add_restraints(system, pdb, stiffness, restraint_set, exclude_residues)
integrator = openmm.LangevinIntegrator(0, 0.01, 0.0)
platform = openmm.Platform.getPlatformByName("CPU")
simulation = openmm_app.Simulation(
pdb.topology, system, integrator, platform
)
simulation.context.setPositions(pdb.positions)
ret = {}
state = simulation.context.getState(getEnergy=True, getPositions=True)
ret["einit"] = state.getPotentialEnergy().value_in_unit(ENERGY)
ret["posinit"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH)
simulation.minimizeEnergy(maxIterations=max_iterations, tolerance=tolerance)
state = simulation.context.getState(getEnergy=True, getPositions=True)
ret["efinal"] = state.getPotentialEnergy().value_in_unit(ENERGY)
ret["pos"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH)
ret["min_pdb"] = _get_pdb_string(simulation.topology, state.getPositions())
return ret
def _get_pdb_string(topology: openmm_app.Topology, positions: unit.Quantity):
"""Returns a pdb string provided OpenMM topology and positions."""
with io.StringIO() as f:
openmm_app.PDBFile.writeFile(topology, positions, f)
return f.getvalue()
"""Returns a pdb string provided OpenMM topology and positions."""
with io.StringIO() as f:
openmm_app.PDBFile.writeFile(topology, positions, f)
return f.getvalue()
def _check_cleaned_atoms(pdb_cleaned_string: str, pdb_ref_string: str):
"""Checks that no atom positions have been altered by cleaning."""
cleaned = openmm_app.PDBFile(io.StringIO(pdb_cleaned_string))
reference = openmm_app.PDBFile(io.StringIO(pdb_ref_string))
cl_xyz = np.array(cleaned.getPositions().value_in_unit(LENGTH))
ref_xyz = np.array(reference.getPositions().value_in_unit(LENGTH))
for ref_res, cl_res in zip(reference.topology.residues(),
cleaned.topology.residues()):
assert ref_res.name == cl_res.name
for rat in ref_res.atoms():
for cat in cl_res.atoms():
if cat.name == rat.name:
if not np.array_equal(cl_xyz[cat.index], ref_xyz[rat.index]):
raise ValueError(f"Coordinates of cleaned atom {cat} do not match "
f"coordinates of reference atom {rat}.")
"""Checks that no atom positions have been altered by cleaning."""
cleaned = openmm_app.PDBFile(io.StringIO(pdb_cleaned_string))
reference = openmm_app.PDBFile(io.StringIO(pdb_ref_string))
cl_xyz = np.array(cleaned.getPositions().value_in_unit(LENGTH))
ref_xyz = np.array(reference.getPositions().value_in_unit(LENGTH))
for ref_res, cl_res in zip(
reference.topology.residues(), cleaned.topology.residues()
):
assert ref_res.name == cl_res.name
for rat in ref_res.atoms():
for cat in cl_res.atoms():
if cat.name == rat.name:
if not np.array_equal(
cl_xyz[cat.index], ref_xyz[rat.index]
):
raise ValueError(
f"Coordinates of cleaned atom {cat} do not match "
f"coordinates of reference atom {rat}."
)
def _check_residues_are_well_defined(prot: protein.Protein):
"""Checks that all residues contain non-empty atom sets."""
if (prot.atom_mask.sum(axis=-1) == 0).any():
raise ValueError("Amber minimization can only be performed on proteins with"
" well-defined residues. This protein contains at least"
" one residue with no atoms.")
"""Checks that all residues contain non-empty atom sets."""
if (prot.atom_mask.sum(axis=-1) == 0).any():
raise ValueError(
"Amber minimization can only be performed on proteins with"
" well-defined residues. This protein contains at least"
" one residue with no atoms."
)
def _check_atom_mask_is_ideal(prot):
"""Sanity-check the atom mask is ideal, up to a possible OXT."""
atom_mask = prot.atom_mask
ideal_atom_mask = protein.ideal_atom_mask(prot)
utils.assert_equal_nonterminal_atom_types(atom_mask, ideal_atom_mask)
"""Sanity-check the atom mask is ideal, up to a possible OXT."""
atom_mask = prot.atom_mask
ideal_atom_mask = protein.ideal_atom_mask(prot)
utils.assert_equal_nonterminal_atom_types(atom_mask, ideal_atom_mask)
def clean_protein(
prot: protein.Protein,
checks: bool = True):
"""Adds missing atoms to Protein instance.
Args:
prot: A `protein.Protein` instance.
checks: A `bool` specifying whether to add additional checks to the cleaning
process.
Returns:
pdb_string: A string of the cleaned protein.
"""
_check_atom_mask_is_ideal(prot)
# Clean pdb.
prot_pdb_string = protein.to_pdb(prot)
pdb_file = io.StringIO(prot_pdb_string)
alterations_info = {}
fixed_pdb = cleanup.fix_pdb(pdb_file, alterations_info)
fixed_pdb_file = io.StringIO(fixed_pdb)
pdb_structure = PdbStructure(fixed_pdb_file)
cleanup.clean_structure(pdb_structure, alterations_info)
logging.info("alterations info: %s", alterations_info)
# Write pdb file of cleaned structure.
as_file = openmm_app.PDBFile(pdb_structure)
pdb_string = _get_pdb_string(as_file.getTopology(), as_file.getPositions())
if checks:
_check_cleaned_atoms(pdb_string, prot_pdb_string)
return pdb_string
def clean_protein(prot: protein.Protein, checks: bool = True):
"""Adds missing atoms to Protein instance.
Args:
prot: A `protein.Protein` instance.
checks: A `bool` specifying whether to add additional checks to the cleaning
process.
Returns:
pdb_string: A string of the cleaned protein.
"""
_check_atom_mask_is_ideal(prot)
# Clean pdb.
prot_pdb_string = protein.to_pdb(prot)
pdb_file = io.StringIO(prot_pdb_string)
alterations_info = {}
fixed_pdb = cleanup.fix_pdb(pdb_file, alterations_info)
fixed_pdb_file = io.StringIO(fixed_pdb)
pdb_structure = PdbStructure(fixed_pdb_file)
cleanup.clean_structure(pdb_structure, alterations_info)
logging.info("alterations info: %s", alterations_info)
# Write pdb file of cleaned structure.
as_file = openmm_app.PDBFile(pdb_structure)
pdb_string = _get_pdb_string(as_file.getTopology(), as_file.getPositions())
if checks:
_check_cleaned_atoms(pdb_string, prot_pdb_string)
return pdb_string
def make_atom14_positions(prot):
"""Constructs denser atom positions (14 dimensions instead of 37)."""
restype_atom14_to_atom37 = [] # mapping (restype, atom14) --> atom37
restype_atom37_to_atom14 = [] # mapping (restype, atom37) --> atom14
restype_atom14_mask = []
for rt in residue_constants.restypes:
atom_names = residue_constants.restype_name_to_atom14_names[
residue_constants.restype_1to3[rt]]
restype_atom14_to_atom37.append([
(residue_constants.atom_order[name] if name else 0)
for name in atom_names
])
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
restype_atom37_to_atom14.append([
(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
for name in residue_constants.atom_types
])
restype_atom14_mask.append([(1. if name else 0.) for name in atom_names])
# Add dummy mapping for restype 'UNK'.
restype_atom14_to_atom37.append([0] * 14)
restype_atom37_to_atom14.append([0] * 37)
restype_atom14_mask.append([0.] * 14)
restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32)
restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32)
restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32)
# Create the mapping for (residx, atom14) --> atom37, i.e. an array
# with shape (num_res, 14) containing the atom37 indices for this protein.
residx_atom14_to_atom37 = restype_atom14_to_atom37[prot["aatype"]]
residx_atom14_mask = restype_atom14_mask[prot["aatype"]]
# Create a mask for known ground truth positions.
residx_atom14_gt_mask = residx_atom14_mask * np.take_along_axis(
prot["all_atom_mask"], residx_atom14_to_atom37, axis=1).astype(np.float32)
# Gather the ground truth positions.
residx_atom14_gt_positions = residx_atom14_gt_mask[:, :, None] * (
np.take_along_axis(prot["all_atom_positions"],
residx_atom14_to_atom37[..., None],
axis=1))
prot["atom14_atom_exists"] = residx_atom14_mask
prot["atom14_gt_exists"] = residx_atom14_gt_mask
prot["atom14_gt_positions"] = residx_atom14_gt_positions
prot["residx_atom14_to_atom37"] = residx_atom14_to_atom37.astype(np.int64)
# Create the gather indices for mapping back.
residx_atom37_to_atom14 = restype_atom37_to_atom14[prot["aatype"]]
prot["residx_atom37_to_atom14"] = residx_atom37_to_atom14.astype(np.int64)
# Create the corresponding mask.
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
for restype, restype_letter in enumerate(residue_constants.restypes):
restype_name = residue_constants.restype_1to3[restype_letter]
atom_names = residue_constants.residue_atoms[restype_name]
for atom_name in atom_names:
atom_type = residue_constants.atom_order[atom_name]
restype_atom37_mask[restype, atom_type] = 1
residx_atom37_mask = restype_atom37_mask[prot["aatype"]]
prot["atom37_atom_exists"] = residx_atom37_mask
# As the atom naming is ambiguous for 7 of the 20 amino acids, provide
# alternative ground truth coordinates where the naming is swapped
restype_3 = [
residue_constants.restype_1to3[res] for res in residue_constants.restypes
]
restype_3 += ["UNK"]
# Matrices for renaming ambiguous atoms.
all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3}
for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
correspondences = np.arange(14)
for source_atom_swap, target_atom_swap in swap.items():
source_index = residue_constants.restype_name_to_atom14_names[
resname].index(source_atom_swap)
target_index = residue_constants.restype_name_to_atom14_names[
resname].index(target_atom_swap)
correspondences[source_index] = target_index
correspondences[target_index] = source_index
renaming_matrix = np.zeros((14, 14), dtype=np.float32)
for index, correspondence in enumerate(correspondences):
renaming_matrix[index, correspondence] = 1.
all_matrices[resname] = renaming_matrix.astype(np.float32)
renaming_matrices = np.stack([all_matrices[restype] for restype in restype_3])
# Pick the transformation matrices for the given residue sequence
# shape (num_res, 14, 14).
renaming_transform = renaming_matrices[prot["aatype"]]
# Apply it to the ground truth positions. shape (num_res, 14, 3).
alternative_gt_positions = np.einsum("rac,rab->rbc",
residx_atom14_gt_positions,
renaming_transform)
prot["atom14_alt_gt_positions"] = alternative_gt_positions
# Create the mask for the alternative ground truth (differs from the
# ground truth mask, if only one of the atoms in an ambiguous pair has a
# ground truth position).
alternative_gt_mask = np.einsum("ra,rab->rb",
residx_atom14_gt_mask,
renaming_transform)
prot["atom14_alt_gt_exists"] = alternative_gt_mask
# Create an ambiguous atoms mask. shape: (21, 14).
restype_atom14_is_ambiguous = np.zeros((21, 14), dtype=np.float32)
for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
for atom_name1, atom_name2 in swap.items():
restype = residue_constants.restype_order[
residue_constants.restype_3to1[resname]]
atom_idx1 = residue_constants.restype_name_to_atom14_names[resname].index(
atom_name1)
atom_idx2 = residue_constants.restype_name_to_atom14_names[resname].index(
atom_name2)
restype_atom14_is_ambiguous[restype, atom_idx1] = 1
restype_atom14_is_ambiguous[restype, atom_idx2] = 1
# From this create an ambiguous_mask for the given sequence.
prot["atom14_atom_is_ambiguous"] = (
restype_atom14_is_ambiguous[prot["aatype"]])
return prot
"""Constructs denser atom positions (14 dimensions instead of 37)."""
restype_atom14_to_atom37 = [] # mapping (restype, atom14) --> atom37
restype_atom37_to_atom14 = [] # mapping (restype, atom37) --> atom14
restype_atom14_mask = []
for rt in residue_constants.restypes:
atom_names = residue_constants.restype_name_to_atom14_names[
residue_constants.restype_1to3[rt]
]
restype_atom14_to_atom37.append(
[
(residue_constants.atom_order[name] if name else 0)
for name in atom_names
]
)
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
restype_atom37_to_atom14.append(
[
(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
for name in residue_constants.atom_types
]
)
restype_atom14_mask.append(
[(1.0 if name else 0.0) for name in atom_names]
)
# Add dummy mapping for restype 'UNK'.
restype_atom14_to_atom37.append([0] * 14)
restype_atom37_to_atom14.append([0] * 37)
restype_atom14_mask.append([0.0] * 14)
restype_atom14_to_atom37 = np.array(
restype_atom14_to_atom37, dtype=np.int32
)
restype_atom37_to_atom14 = np.array(
restype_atom37_to_atom14, dtype=np.int32
)
restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32)
# Create the mapping for (residx, atom14) --> atom37, i.e. an array
# with shape (num_res, 14) containing the atom37 indices for this protein.
residx_atom14_to_atom37 = restype_atom14_to_atom37[prot["aatype"]]
residx_atom14_mask = restype_atom14_mask[prot["aatype"]]
# Create a mask for known ground truth positions.
residx_atom14_gt_mask = residx_atom14_mask * np.take_along_axis(
prot["all_atom_mask"], residx_atom14_to_atom37, axis=1
).astype(np.float32)
# Gather the ground truth positions.
residx_atom14_gt_positions = residx_atom14_gt_mask[:, :, None] * (
np.take_along_axis(
prot["all_atom_positions"],
residx_atom14_to_atom37[..., None],
axis=1,
)
)
prot["atom14_atom_exists"] = residx_atom14_mask
prot["atom14_gt_exists"] = residx_atom14_gt_mask
prot["atom14_gt_positions"] = residx_atom14_gt_positions
prot["residx_atom14_to_atom37"] = residx_atom14_to_atom37.astype(np.int64)
# Create the gather indices for mapping back.
residx_atom37_to_atom14 = restype_atom37_to_atom14[prot["aatype"]]
prot["residx_atom37_to_atom14"] = residx_atom37_to_atom14.astype(np.int64)
# Create the corresponding mask.
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
for restype, restype_letter in enumerate(residue_constants.restypes):
restype_name = residue_constants.restype_1to3[restype_letter]
atom_names = residue_constants.residue_atoms[restype_name]
for atom_name in atom_names:
atom_type = residue_constants.atom_order[atom_name]
restype_atom37_mask[restype, atom_type] = 1
residx_atom37_mask = restype_atom37_mask[prot["aatype"]]
prot["atom37_atom_exists"] = residx_atom37_mask
# As the atom naming is ambiguous for 7 of the 20 amino acids, provide
# alternative ground truth coordinates where the naming is swapped
restype_3 = [
residue_constants.restype_1to3[res]
for res in residue_constants.restypes
]
restype_3 += ["UNK"]
# Matrices for renaming ambiguous atoms.
all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3}
for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
correspondences = np.arange(14)
for source_atom_swap, target_atom_swap in swap.items():
source_index = residue_constants.restype_name_to_atom14_names[
resname
].index(source_atom_swap)
target_index = residue_constants.restype_name_to_atom14_names[
resname
].index(target_atom_swap)
correspondences[source_index] = target_index
correspondences[target_index] = source_index
renaming_matrix = np.zeros((14, 14), dtype=np.float32)
for index, correspondence in enumerate(correspondences):
renaming_matrix[index, correspondence] = 1.0
all_matrices[resname] = renaming_matrix.astype(np.float32)
renaming_matrices = np.stack(
[all_matrices[restype] for restype in restype_3]
)
# Pick the transformation matrices for the given residue sequence
# shape (num_res, 14, 14).
renaming_transform = renaming_matrices[prot["aatype"]]
# Apply it to the ground truth positions. shape (num_res, 14, 3).
alternative_gt_positions = np.einsum(
"rac,rab->rbc", residx_atom14_gt_positions, renaming_transform
)
prot["atom14_alt_gt_positions"] = alternative_gt_positions
# Create the mask for the alternative ground truth (differs from the
# ground truth mask, if only one of the atoms in an ambiguous pair has a
# ground truth position).
alternative_gt_mask = np.einsum(
"ra,rab->rb", residx_atom14_gt_mask, renaming_transform
)
prot["atom14_alt_gt_exists"] = alternative_gt_mask
# Create an ambiguous atoms mask. shape: (21, 14).
restype_atom14_is_ambiguous = np.zeros((21, 14), dtype=np.float32)
for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
for atom_name1, atom_name2 in swap.items():
restype = residue_constants.restype_order[
residue_constants.restype_3to1[resname]
]
atom_idx1 = residue_constants.restype_name_to_atom14_names[
resname
].index(atom_name1)
atom_idx2 = residue_constants.restype_name_to_atom14_names[
resname
].index(atom_name2)
restype_atom14_is_ambiguous[restype, atom_idx1] = 1
restype_atom14_is_ambiguous[restype, atom_idx2] = 1
# From this create an ambiguous_mask for the given sequence.
prot["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[
prot["aatype"]
]
return prot
def find_violations(prot_np: protein.Protein):
"""Analyzes a protein and returns structural violation information.
Args:
prot_np: A protein.
Returns:
violations: A `dict` of structure components with structural violations.
violation_metrics: A `dict` of violation metrics.
"""
batch = {
"aatype": prot_np.aatype,
"all_atom_positions": prot_np.atom_positions.astype(np.float32),
"all_atom_mask": prot_np.atom_mask.astype(np.float32),
"residue_index": prot_np.residue_index,
}
batch["seq_mask"] = np.ones_like(batch["aatype"], np.float32)
batch = make_atom14_positions(batch)
violations = loss.find_structural_violations_np(
batch=batch,
atom14_pred_positions=batch["atom14_gt_positions"],
config=ml_collections.ConfigDict(
{"violation_tolerance_factor": 12, # Taken from model config.
"clash_overlap_tolerance": 1.5, # Taken from model config.
}))
violation_metrics = loss.compute_violation_metrics_np(
batch=batch,
atom14_pred_positions=batch["atom14_gt_positions"],
violations=violations,
)
return violations, violation_metrics
"""Analyzes a protein and returns structural violation information.
Args:
prot_np: A protein.
Returns:
violations: A `dict` of structure components with structural violations.
violation_metrics: A `dict` of violation metrics.
"""
batch = {
"aatype": prot_np.aatype,
"all_atom_positions": prot_np.atom_positions.astype(np.float32),
"all_atom_mask": prot_np.atom_mask.astype(np.float32),
"residue_index": prot_np.residue_index,
}
batch["seq_mask"] = np.ones_like(batch["aatype"], np.float32)
batch = make_atom14_positions(batch)
violations = loss.find_structural_violations_np(
batch=batch,
atom14_pred_positions=batch["atom14_gt_positions"],
config=ml_collections.ConfigDict(
{
"violation_tolerance_factor": 12, # Taken from model config.
"clash_overlap_tolerance": 1.5, # Taken from model config.
}
),
)
violation_metrics = loss.compute_violation_metrics_np(
batch=batch,
atom14_pred_positions=batch["atom14_gt_positions"],
violations=violations,
)
return violations, violation_metrics
def get_violation_metrics(prot: protein.Protein):
"""Computes violation and alignment metrics."""
structural_violations, struct_metrics = find_violations(prot)
violation_idx = np.flatnonzero(
structural_violations["total_per_residue_violations_mask"])
"""Computes violation and alignment metrics."""
structural_violations, struct_metrics = find_violations(prot)
violation_idx = np.flatnonzero(
structural_violations["total_per_residue_violations_mask"]
)
struct_metrics["residue_violations"] = violation_idx
struct_metrics["num_residue_violations"] = len(violation_idx)
struct_metrics["structural_violations"] = structural_violations
return struct_metrics
struct_metrics["residue_violations"] = violation_idx
struct_metrics["num_residue_violations"] = len(violation_idx)
struct_metrics["structural_violations"] = structural_violations
return struct_metrics
def _run_one_iteration(
......@@ -372,51 +410,56 @@ def _run_one_iteration(
stiffness: float,
restraint_set: str,
max_attempts: int,
exclude_residues: Optional[Collection[int]] = None):
"""Runs the minimization pipeline.
Args:
pdb_string: A pdb string.
max_iterations: An `int` specifying the maximum number of L-BFGS iterations.
A value of 0 specifies no limit.
tolerance: kcal/mol, the energy tolerance of L-BFGS.
stiffness: kcal/mol A**2, spring constant of heavy atom restraining
potential.
restraint_set: The set of atoms to restrain.
max_attempts: The maximum number of minimization attempts.
exclude_residues: An optional list of zero-indexed residues to exclude from
restraints.
Returns:
A `dict` of minimization info.
"""
exclude_residues = exclude_residues or []
# Assign physical dimensions.
tolerance = tolerance * ENERGY
stiffness = stiffness * ENERGY / (LENGTH**2)
start = time.time()
minimized = False
attempts = 0
while not minimized and attempts < max_attempts:
attempts += 1
try:
logging.info("Minimizing protein, attempt %d of %d.",
attempts, max_attempts)
ret = _openmm_minimize(
pdb_string, max_iterations=max_iterations,
tolerance=tolerance, stiffness=stiffness,
restraint_set=restraint_set,
exclude_residues=exclude_residues)
minimized = True
except Exception as e: # pylint: disable=broad-except
logging.info(e)
if not minimized:
raise ValueError(f"Minimization failed after {max_attempts} attempts.")
ret["opt_time"] = time.time() - start
ret["min_attempts"] = attempts
return ret
exclude_residues: Optional[Collection[int]] = None,
):
"""Runs the minimization pipeline.
Args:
pdb_string: A pdb string.
max_iterations: An `int` specifying the maximum number of L-BFGS iterations.
A value of 0 specifies no limit.
tolerance: kcal/mol, the energy tolerance of L-BFGS.
stiffness: kcal/mol A**2, spring constant of heavy atom restraining
potential.
restraint_set: The set of atoms to restrain.
max_attempts: The maximum number of minimization attempts.
exclude_residues: An optional list of zero-indexed residues to exclude from
restraints.
Returns:
A `dict` of minimization info.
"""
exclude_residues = exclude_residues or []
# Assign physical dimensions.
tolerance = tolerance * ENERGY
stiffness = stiffness * ENERGY / (LENGTH ** 2)
start = time.time()
minimized = False
attempts = 0
while not minimized and attempts < max_attempts:
attempts += 1
try:
logging.info(
"Minimizing protein, attempt %d of %d.", attempts, max_attempts
)
ret = _openmm_minimize(
pdb_string,
max_iterations=max_iterations,
tolerance=tolerance,
stiffness=stiffness,
restraint_set=restraint_set,
exclude_residues=exclude_residues,
)
minimized = True
except Exception as e: # pylint: disable=broad-except
logging.info(e)
if not minimized:
raise ValueError(f"Minimization failed after {max_attempts} attempts.")
ret["opt_time"] = time.time() - start
ret["min_attempts"] = attempts
return ret
def run_pipeline(
......@@ -429,116 +472,134 @@ def run_pipeline(
restraint_set: str = "non_hydrogen",
max_attempts: int = 100,
checks: bool = True,
exclude_residues: Optional[Sequence[int]] = None):
"""Run iterative amber relax.
Successive relax iterations are performed until all violations have been
resolved. Each iteration involves a restrained Amber minimization, with
restraint exclusions determined by violation-participating residues.
Args:
prot: A protein to be relaxed.
stiffness: kcal/mol A**2, the restraint stiffness.
max_outer_iterations: The maximum number of iterative minimization.
place_hydrogens_every_iteration: Whether hydrogens are re-initialized
prior to every minimization.
max_iterations: An `int` specifying the maximum number of L-BFGS steps
per relax iteration. A value of 0 specifies no limit.
tolerance: kcal/mol, the energy tolerance of L-BFGS.
The default value is the OpenMM default.
restraint_set: The set of atoms to restrain.
max_attempts: The maximum number of minimization attempts per iteration.
checks: Whether to perform cleaning checks.
exclude_residues: An optional list of zero-indexed residues to exclude from
restraints.
Returns:
out: A dictionary of output values.
"""
# `protein.to_pdb` will strip any poorly-defined residues so we need to
# perform this check before `clean_protein`.
_check_residues_are_well_defined(prot)
pdb_string = clean_protein(prot, checks=checks)
exclude_residues = exclude_residues or []
exclude_residues = set(exclude_residues)
violations = np.inf
iteration = 0
while violations > 0 and iteration < max_outer_iterations:
ret = _run_one_iteration(
pdb_string=pdb_string,
exclude_residues=exclude_residues,
max_iterations=max_iterations,
tolerance=tolerance,
stiffness=stiffness,
restraint_set=restraint_set,
max_attempts=max_attempts)
prot = protein.from_pdb_string(ret["min_pdb"])
if place_hydrogens_every_iteration:
pdb_string = clean_protein(prot, checks=True)
else:
pdb_string = ret["min_pdb"]
ret.update(get_violation_metrics(prot))
ret.update({
"num_exclusions": len(exclude_residues),
"iteration": iteration,
})
violations = ret["violations_per_residue"]
exclude_residues = exclude_residues.union(ret["residue_violations"])
logging.info("Iteration completed: Einit %.2f Efinal %.2f Time %.2f s "
"num residue violations %d num residue exclusions %d ",
ret["einit"], ret["efinal"], ret["opt_time"],
ret["num_residue_violations"], ret["num_exclusions"])
iteration += 1
return ret
def get_initial_energies(pdb_strs: Sequence[str],
stiffness: float = 0.0,
restraint_set: str = "non_hydrogen",
exclude_residues: Optional[Sequence[int]] = None):
"""Returns initial potential energies for a sequence of PDBs.
Assumes the input PDBs are ready for minimization, and all have the same
topology.
Allows time to be saved by not pdbfixing / rebuilding the system.
Args:
pdb_strs: List of PDB strings.
stiffness: kcal/mol A**2, spring constant of heavy atom restraining
potential.
restraint_set: Which atom types to restrain.
exclude_residues: An optional list of zero-indexed residues to exclude from
restraints.
Returns:
A list of initial energies in the same order as pdb_strs.
"""
exclude_residues = exclude_residues or []
openmm_pdbs = [openmm_app.PDBFile(PdbStructure(io.StringIO(p)))
for p in pdb_strs]
force_field = openmm_app.ForceField("amber99sb.xml")
system = force_field.createSystem(openmm_pdbs[0].topology,
constraints=openmm_app.HBonds)
stiffness = stiffness * ENERGY / (LENGTH**2)
if stiffness > 0 * ENERGY / (LENGTH**2):
_add_restraints(system, openmm_pdbs[0], stiffness, restraint_set,
exclude_residues)
simulation = openmm_app.Simulation(openmm_pdbs[0].topology,
system,
openmm.LangevinIntegrator(0, 0.01, 0.0),
openmm.Platform.getPlatformByName("CPU"))
energies = []
for pdb in openmm_pdbs:
try:
simulation.context.setPositions(pdb.positions)
state = simulation.context.getState(getEnergy=True)
energies.append(state.getPotentialEnergy().value_in_unit(ENERGY))
except Exception as e: # pylint: disable=broad-except
logging.error("Error getting initial energy, returning large value %s", e)
energies.append(unit.Quantity(1e20, ENERGY))
return energies
exclude_residues: Optional[Sequence[int]] = None,
):
"""Run iterative amber relax.
Successive relax iterations are performed until all violations have been
resolved. Each iteration involves a restrained Amber minimization, with
restraint exclusions determined by violation-participating residues.
Args:
prot: A protein to be relaxed.
stiffness: kcal/mol A**2, the restraint stiffness.
max_outer_iterations: The maximum number of iterative minimization.
place_hydrogens_every_iteration: Whether hydrogens are re-initialized
prior to every minimization.
max_iterations: An `int` specifying the maximum number of L-BFGS steps
per relax iteration. A value of 0 specifies no limit.
tolerance: kcal/mol, the energy tolerance of L-BFGS.
The default value is the OpenMM default.
restraint_set: The set of atoms to restrain.
max_attempts: The maximum number of minimization attempts per iteration.
checks: Whether to perform cleaning checks.
exclude_residues: An optional list of zero-indexed residues to exclude from
restraints.
Returns:
out: A dictionary of output values.
"""
# `protein.to_pdb` will strip any poorly-defined residues so we need to
# perform this check before `clean_protein`.
_check_residues_are_well_defined(prot)
pdb_string = clean_protein(prot, checks=checks)
exclude_residues = exclude_residues or []
exclude_residues = set(exclude_residues)
violations = np.inf
iteration = 0
while violations > 0 and iteration < max_outer_iterations:
ret = _run_one_iteration(
pdb_string=pdb_string,
exclude_residues=exclude_residues,
max_iterations=max_iterations,
tolerance=tolerance,
stiffness=stiffness,
restraint_set=restraint_set,
max_attempts=max_attempts,
)
prot = protein.from_pdb_string(ret["min_pdb"])
if place_hydrogens_every_iteration:
pdb_string = clean_protein(prot, checks=True)
else:
pdb_string = ret["min_pdb"]
ret.update(get_violation_metrics(prot))
ret.update(
{
"num_exclusions": len(exclude_residues),
"iteration": iteration,
}
)
violations = ret["violations_per_residue"]
exclude_residues = exclude_residues.union(ret["residue_violations"])
logging.info(
"Iteration completed: Einit %.2f Efinal %.2f Time %.2f s "
"num residue violations %d num residue exclusions %d ",
ret["einit"],
ret["efinal"],
ret["opt_time"],
ret["num_residue_violations"],
ret["num_exclusions"],
)
iteration += 1
return ret
def get_initial_energies(
pdb_strs: Sequence[str],
stiffness: float = 0.0,
restraint_set: str = "non_hydrogen",
exclude_residues: Optional[Sequence[int]] = None,
):
"""Returns initial potential energies for a sequence of PDBs.
Assumes the input PDBs are ready for minimization, and all have the same
topology.
Allows time to be saved by not pdbfixing / rebuilding the system.
Args:
pdb_strs: List of PDB strings.
stiffness: kcal/mol A**2, spring constant of heavy atom restraining
potential.
restraint_set: Which atom types to restrain.
exclude_residues: An optional list of zero-indexed residues to exclude from
restraints.
Returns:
A list of initial energies in the same order as pdb_strs.
"""
exclude_residues = exclude_residues or []
openmm_pdbs = [
openmm_app.PDBFile(PdbStructure(io.StringIO(p))) for p in pdb_strs
]
force_field = openmm_app.ForceField("amber99sb.xml")
system = force_field.createSystem(
openmm_pdbs[0].topology, constraints=openmm_app.HBonds
)
stiffness = stiffness * ENERGY / (LENGTH ** 2)
if stiffness > 0 * ENERGY / (LENGTH ** 2):
_add_restraints(
system, openmm_pdbs[0], stiffness, restraint_set, exclude_residues
)
simulation = openmm_app.Simulation(
openmm_pdbs[0].topology,
system,
openmm.LangevinIntegrator(0, 0.01, 0.0),
openmm.Platform.getPlatformByName("CPU"),
)
energies = []
for pdb in openmm_pdbs:
try:
simulation.context.setPositions(pdb.positions)
state = simulation.context.getState(getEnergy=True)
energies.append(state.getPotentialEnergy().value_in_unit(ENERGY))
except Exception as e: # pylint: disable=broad-except
logging.error(
"Error getting initial energy, returning large value %s", e
)
energies.append(unit.Quantity(1e20, ENERGY))
return energies
......@@ -25,103 +25,107 @@ from simtk.openmm.app import element
def fix_pdb(pdbfile, alterations_info):
"""Apply pdbfixer to the contents of a PDB file; return a PDB string result.
1) Replaces nonstandard residues.
2) Removes heterogens (non protein residues) including water.
3) Adds missing residues and missing atoms within existing residues.
4) Adds hydrogens assuming pH=7.0.
5) KeepIds is currently true, so the fixer must keep the existing chain and
residue identifiers. This will fail for some files in wider PDB that have
invalid IDs.
Args:
pdbfile: Input PDB file handle.
alterations_info: A dict that will store details of changes made.
Returns:
A PDB string representing the fixed structure.
"""
fixer = pdbfixer.PDBFixer(pdbfile=pdbfile)
fixer.findNonstandardResidues()
alterations_info['nonstandard_residues'] = fixer.nonstandardResidues
fixer.replaceNonstandardResidues()
_remove_heterogens(fixer, alterations_info, keep_water=False)
fixer.findMissingResidues()
alterations_info['missing_residues'] = fixer.missingResidues
fixer.findMissingAtoms()
alterations_info['missing_heavy_atoms'] = fixer.missingAtoms
alterations_info['missing_terminals'] = fixer.missingTerminals
fixer.addMissingAtoms(seed=0)
fixer.addMissingHydrogens()
out_handle = io.StringIO()
app.PDBFile.writeFile(fixer.topology, fixer.positions, out_handle,
keepIds=True)
return out_handle.getvalue()
"""Apply pdbfixer to the contents of a PDB file; return a PDB string result.
1) Replaces nonstandard residues.
2) Removes heterogens (non protein residues) including water.
3) Adds missing residues and missing atoms within existing residues.
4) Adds hydrogens assuming pH=7.0.
5) KeepIds is currently true, so the fixer must keep the existing chain and
residue identifiers. This will fail for some files in wider PDB that have
invalid IDs.
Args:
pdbfile: Input PDB file handle.
alterations_info: A dict that will store details of changes made.
Returns:
A PDB string representing the fixed structure.
"""
fixer = pdbfixer.PDBFixer(pdbfile=pdbfile)
fixer.findNonstandardResidues()
alterations_info["nonstandard_residues"] = fixer.nonstandardResidues
fixer.replaceNonstandardResidues()
_remove_heterogens(fixer, alterations_info, keep_water=False)
fixer.findMissingResidues()
alterations_info["missing_residues"] = fixer.missingResidues
fixer.findMissingAtoms()
alterations_info["missing_heavy_atoms"] = fixer.missingAtoms
alterations_info["missing_terminals"] = fixer.missingTerminals
fixer.addMissingAtoms(seed=0)
fixer.addMissingHydrogens()
out_handle = io.StringIO()
app.PDBFile.writeFile(
fixer.topology, fixer.positions, out_handle, keepIds=True
)
return out_handle.getvalue()
def clean_structure(pdb_structure, alterations_info):
"""Applies additional fixes to an OpenMM structure, to handle edge cases.
"""Applies additional fixes to an OpenMM structure, to handle edge cases.
Args:
pdb_structure: An OpenMM structure to modify and fix.
alterations_info: A dict that will store details of changes made.
"""
_replace_met_se(pdb_structure, alterations_info)
_remove_chains_of_length_one(pdb_structure, alterations_info)
Args:
pdb_structure: An OpenMM structure to modify and fix.
alterations_info: A dict that will store details of changes made.
"""
_replace_met_se(pdb_structure, alterations_info)
_remove_chains_of_length_one(pdb_structure, alterations_info)
def _remove_heterogens(fixer, alterations_info, keep_water):
"""Removes the residues that Pdbfixer considers to be heterogens.
Args:
fixer: A Pdbfixer instance.
alterations_info: A dict that will store details of changes made.
keep_water: If True, water (HOH) is not considered to be a heterogen.
"""
initial_resnames = set()
for chain in fixer.topology.chains():
for residue in chain.residues():
initial_resnames.add(residue.name)
fixer.removeHeterogens(keepWater=keep_water)
final_resnames = set()
for chain in fixer.topology.chains():
for residue in chain.residues():
final_resnames.add(residue.name)
alterations_info['removed_heterogens'] = (
initial_resnames.difference(final_resnames))
"""Removes the residues that Pdbfixer considers to be heterogens.
Args:
fixer: A Pdbfixer instance.
alterations_info: A dict that will store details of changes made.
keep_water: If True, water (HOH) is not considered to be a heterogen.
"""
initial_resnames = set()
for chain in fixer.topology.chains():
for residue in chain.residues():
initial_resnames.add(residue.name)
fixer.removeHeterogens(keepWater=keep_water)
final_resnames = set()
for chain in fixer.topology.chains():
for residue in chain.residues():
final_resnames.add(residue.name)
alterations_info["removed_heterogens"] = initial_resnames.difference(
final_resnames
)
def _replace_met_se(pdb_structure, alterations_info):
"""Replace the Se in any MET residues that were not marked as modified."""
modified_met_residues = []
for res in pdb_structure.iter_residues():
name = res.get_name_with_spaces().strip()
if name == 'MET':
s_atom = res.get_atom('SD')
if s_atom.element_symbol == 'Se':
s_atom.element_symbol = 'S'
s_atom.element = element.get_by_symbol('S')
modified_met_residues.append(s_atom.residue_number)
alterations_info['Se_in_MET'] = modified_met_residues
"""Replace the Se in any MET residues that were not marked as modified."""
modified_met_residues = []
for res in pdb_structure.iter_residues():
name = res.get_name_with_spaces().strip()
if name == "MET":
s_atom = res.get_atom("SD")
if s_atom.element_symbol == "Se":
s_atom.element_symbol = "S"
s_atom.element = element.get_by_symbol("S")
modified_met_residues.append(s_atom.residue_number)
alterations_info["Se_in_MET"] = modified_met_residues
def _remove_chains_of_length_one(pdb_structure, alterations_info):
"""Removes chains that correspond to a single amino acid.
A single amino acid in a chain is both N and C terminus. There is no force
template for this case.
Args:
pdb_structure: An OpenMM pdb_structure to modify and fix.
alterations_info: A dict that will store details of changes made.
"""
removed_chains = {}
for model in pdb_structure.iter_models():
valid_chains = [c for c in model.iter_chains() if len(c) > 1]
invalid_chain_ids = [c.chain_id for c in model.iter_chains() if len(c) <= 1]
model.chains = valid_chains
for chain_id in invalid_chain_ids:
model.chains_by_id.pop(chain_id)
removed_chains[model.number] = invalid_chain_ids
alterations_info['removed_chains'] = removed_chains
"""Removes chains that correspond to a single amino acid.
A single amino acid in a chain is both N and C terminus. There is no force
template for this case.
Args:
pdb_structure: An OpenMM pdb_structure to modify and fix.
alterations_info: A dict that will store details of changes made.
"""
removed_chains = {}
for model in pdb_structure.iter_models():
valid_chains = [c for c in model.iter_chains() if len(c) > 1]
invalid_chain_ids = [
c.chain_id for c in model.iter_chains() if len(c) <= 1
]
model.chains = valid_chains
for chain_id in invalid_chain_ids:
model.chains_by_id.pop(chain_id)
removed_chains[model.number] = invalid_chain_ids
alterations_info["removed_chains"] = removed_chains
......@@ -21,60 +21,67 @@ import numpy as np
class AmberRelaxation(object):
"""Amber relaxation."""
"""Amber relaxation."""
def __init__(self,
*,
max_iterations: int,
tolerance: float,
stiffness: float,
exclude_residues: Sequence[int],
max_outer_iterations: int):
"""Initialize Amber Relaxer.
def __init__(
self,
*,
max_iterations: int,
tolerance: float,
stiffness: float,
exclude_residues: Sequence[int],
max_outer_iterations: int
):
"""Initialize Amber Relaxer.
Args:
max_iterations: Maximum number of L-BFGS iterations. 0 means no max.
tolerance: kcal/mol, the energy tolerance of L-BFGS.
stiffness: kcal/mol A**2, spring constant of heavy atom restraining
potential.
exclude_residues: Residues to exclude from per-atom restraining.
Zero-indexed.
max_outer_iterations: Maximum number of violation-informed relax
iterations. A value of 1 will run the non-iterative procedure used in
CASP14. Use 20 so that >95% of the bad cases are relaxed. Relax finishes
as soon as there are no violations, hence in most cases this causes no
slowdown. In the worst case we do 20 outer iterations.
"""
Args:
max_iterations: Maximum number of L-BFGS iterations. 0 means no max.
tolerance: kcal/mol, the energy tolerance of L-BFGS.
stiffness: kcal/mol A**2, spring constant of heavy atom restraining
potential.
exclude_residues: Residues to exclude from per-atom restraining.
Zero-indexed.
max_outer_iterations: Maximum number of violation-informed relax
iterations. A value of 1 will run the non-iterative procedure used in
CASP14. Use 20 so that >95% of the bad cases are relaxed. Relax finishes
as soon as there are no violations, hence in most cases this causes no
slowdown. In the worst case we do 20 outer iterations.
"""
self._max_iterations = max_iterations
self._tolerance = tolerance
self._stiffness = stiffness
self._exclude_residues = exclude_residues
self._max_outer_iterations = max_outer_iterations
self._max_iterations = max_iterations
self._tolerance = tolerance
self._stiffness = stiffness
self._exclude_residues = exclude_residues
self._max_outer_iterations = max_outer_iterations
def process(self, *,
prot: protein.Protein) -> Tuple[str, Dict[str, Any], np.ndarray]:
"""Runs Amber relax on a prediction, adds hydrogens, returns PDB string."""
out = amber_minimize.run_pipeline(
prot=prot, max_iterations=self._max_iterations,
tolerance=self._tolerance, stiffness=self._stiffness,
exclude_residues=self._exclude_residues,
max_outer_iterations=self._max_outer_iterations)
min_pos = out['pos']
start_pos = out['posinit']
rmsd = np.sqrt(np.sum((start_pos - min_pos)**2) / start_pos.shape[0])
debug_data = {
'initial_energy': out['einit'],
'final_energy': out['efinal'],
'attempts': out['min_attempts'],
'rmsd': rmsd
}
pdb_str = amber_minimize.clean_protein(prot)
min_pdb = utils.overwrite_pdb_coordinates(pdb_str, min_pos)
min_pdb = utils.overwrite_b_factors(min_pdb, prot.b_factors)
utils.assert_equal_nonterminal_atom_types(
protein.from_pdb_string(min_pdb).atom_mask,
prot.atom_mask)
violations = out['structural_violations'][
'total_per_residue_violations_mask']
return min_pdb, debug_data, violations
def process(
self, *, prot: protein.Protein
) -> Tuple[str, Dict[str, Any], np.ndarray]:
"""Runs Amber relax on a prediction, adds hydrogens, returns PDB string."""
out = amber_minimize.run_pipeline(
prot=prot,
max_iterations=self._max_iterations,
tolerance=self._tolerance,
stiffness=self._stiffness,
exclude_residues=self._exclude_residues,
max_outer_iterations=self._max_outer_iterations,
)
min_pos = out["pos"]
start_pos = out["posinit"]
rmsd = np.sqrt(np.sum((start_pos - min_pos) ** 2) / start_pos.shape[0])
debug_data = {
"initial_energy": out["einit"],
"final_energy": out["efinal"],
"attempts": out["min_attempts"],
"rmsd": rmsd,
}
pdb_str = amber_minimize.clean_protein(prot)
min_pdb = utils.overwrite_pdb_coordinates(pdb_str, min_pos)
min_pdb = utils.overwrite_b_factors(min_pdb, prot.b_factors)
utils.assert_equal_nonterminal_atom_types(
protein.from_pdb_string(min_pdb).atom_mask, prot.atom_mask
)
violations = out["structural_violations"][
"total_per_residue_violations_mask"
]
return min_pdb, debug_data, violations
......@@ -23,59 +23,64 @@ from simtk.openmm.app.internal.pdbstructure import PdbStructure
def overwrite_pdb_coordinates(pdb_str: str, pos) -> str:
pdb_file = io.StringIO(pdb_str)
structure = PdbStructure(pdb_file)
topology = openmm_app.PDBFile(structure).getTopology()
with io.StringIO() as f:
openmm_app.PDBFile.writeFile(topology, pos, f)
return f.getvalue()
pdb_file = io.StringIO(pdb_str)
structure = PdbStructure(pdb_file)
topology = openmm_app.PDBFile(structure).getTopology()
with io.StringIO() as f:
openmm_app.PDBFile.writeFile(topology, pos, f)
return f.getvalue()
def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str:
"""Overwrites the B-factors in pdb_str with contents of bfactors array.
"""Overwrites the B-factors in pdb_str with contents of bfactors array.
Args:
pdb_str: An input PDB string.
bfactors: A numpy array with shape [1, n_residues, 37]. We assume that the
B-factors are per residue; i.e. that the nonzero entries are identical in
[0, i, :].
Args:
pdb_str: An input PDB string.
bfactors: A numpy array with shape [1, n_residues, 37]. We assume that the
B-factors are per residue; i.e. that the nonzero entries are identical in
[0, i, :].
Returns:
A new PDB string with the B-factors replaced.
"""
if bfactors.shape[-1] != residue_constants.atom_type_num:
raise ValueError(
f'Invalid final dimension size for bfactors: {bfactors.shape[-1]}.')
Returns:
A new PDB string with the B-factors replaced.
"""
if bfactors.shape[-1] != residue_constants.atom_type_num:
raise ValueError(
f"Invalid final dimension size for bfactors: {bfactors.shape[-1]}."
)
parser = PDB.PDBParser(QUIET=True)
handle = io.StringIO(pdb_str)
structure = parser.get_structure('', handle)
parser = PDB.PDBParser(QUIET=True)
handle = io.StringIO(pdb_str)
structure = parser.get_structure("", handle)
curr_resid = ('', '', '')
idx = -1
for atom in structure.get_atoms():
atom_resid = atom.parent.get_id()
if atom_resid != curr_resid:
idx += 1
if idx >= bfactors.shape[0]:
raise ValueError('Index into bfactors exceeds number of residues. '
'B-factors shape: {shape}, idx: {idx}.')
curr_resid = atom_resid
atom.bfactor = bfactors[idx, residue_constants.atom_order['CA']]
curr_resid = ("", "", "")
idx = -1
for atom in structure.get_atoms():
atom_resid = atom.parent.get_id()
if atom_resid != curr_resid:
idx += 1
if idx >= bfactors.shape[0]:
raise ValueError(
"Index into bfactors exceeds number of residues. "
"B-factors shape: {shape}, idx: {idx}."
)
curr_resid = atom_resid
atom.bfactor = bfactors[idx, residue_constants.atom_order["CA"]]
new_pdb = io.StringIO()
pdb_io = PDB.PDBIO()
pdb_io.set_structure(structure)
pdb_io.save(new_pdb)
return new_pdb.getvalue()
new_pdb = io.StringIO()
pdb_io = PDB.PDBIO()
pdb_io.set_structure(structure)
pdb_io.save(new_pdb)
return new_pdb.getvalue()
def assert_equal_nonterminal_atom_types(
atom_mask: np.ndarray, ref_atom_mask: np.ndarray):
"""Checks that pre- and post-minimized proteins have same atom set."""
# Ignore any terminal OXT atoms which may have been added by minimization.
oxt = residue_constants.atom_order['OXT']
no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=np.bool)
no_oxt_mask[..., oxt] = False
np.testing.assert_almost_equal(ref_atom_mask[no_oxt_mask],
atom_mask[no_oxt_mask])
atom_mask: np.ndarray, ref_atom_mask: np.ndarray
):
"""Checks that pre- and post-minimized proteins have same atom set."""
# Ignore any terminal OXT atoms which may have been added by minimization.
oxt = residue_constants.atom_order["OXT"]
no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=np.bool)
no_oxt_mask[..., oxt] = False
np.testing.assert_almost_equal(
ref_atom_mask[no_oxt_mask], atom_mask[no_oxt_mask]
)
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
......@@ -32,32 +32,49 @@ ca_ca = 3.80209737096
# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have
# chi angles so their chi angle lists are empty.
chi_angles_atoms = {
'ALA': [],
"ALA": [],
# Chi5 in arginine is always 0 +- 5 degrees, so ignore it.
'ARG': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
['CB', 'CG', 'CD', 'NE'], ['CG', 'CD', 'NE', 'CZ']],
'ASN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']],
'ASP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']],
'CYS': [['N', 'CA', 'CB', 'SG']],
'GLN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
['CB', 'CG', 'CD', 'OE1']],
'GLU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
['CB', 'CG', 'CD', 'OE1']],
'GLY': [],
'HIS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'ND1']],
'ILE': [['N', 'CA', 'CB', 'CG1'], ['CA', 'CB', 'CG1', 'CD1']],
'LEU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
'LYS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
['CB', 'CG', 'CD', 'CE'], ['CG', 'CD', 'CE', 'NZ']],
'MET': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'SD'],
['CB', 'CG', 'SD', 'CE']],
'PHE': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
'PRO': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD']],
'SER': [['N', 'CA', 'CB', 'OG']],
'THR': [['N', 'CA', 'CB', 'OG1']],
'TRP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
'TYR': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
'VAL': [['N', 'CA', 'CB', 'CG1']],
"ARG": [
["N", "CA", "CB", "CG"],
["CA", "CB", "CG", "CD"],
["CB", "CG", "CD", "NE"],
["CG", "CD", "NE", "CZ"],
],
"ASN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
"ASP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
"CYS": [["N", "CA", "CB", "SG"]],
"GLN": [
["N", "CA", "CB", "CG"],
["CA", "CB", "CG", "CD"],
["CB", "CG", "CD", "OE1"],
],
"GLU": [
["N", "CA", "CB", "CG"],
["CA", "CB", "CG", "CD"],
["CB", "CG", "CD", "OE1"],
],
"GLY": [],
"HIS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "ND1"]],
"ILE": [["N", "CA", "CB", "CG1"], ["CA", "CB", "CG1", "CD1"]],
"LEU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
"LYS": [
["N", "CA", "CB", "CG"],
["CA", "CB", "CG", "CD"],
["CB", "CG", "CD", "CE"],
["CG", "CD", "CE", "NZ"],
],
"MET": [
["N", "CA", "CB", "CG"],
["CA", "CB", "CG", "SD"],
["CB", "CG", "SD", "CE"],
],
"PHE": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
"PRO": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"]],
"SER": [["N", "CA", "CB", "OG"]],
"THR": [["N", "CA", "CB", "OG1"]],
"TRP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
"TYR": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
"VAL": [["N", "CA", "CB", "CG1"]],
}
# If chi angles given in fixed-length array, this matrix determines how to mask
......@@ -124,240 +141,266 @@ chi_pi_periodic = [
# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate).
# format: [atomname, group_idx, rel_position]
rigid_group_atom_positions = {
'ALA': [
['N', 0, (-0.525, 1.363, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.526, -0.000, -0.000)],
['CB', 0, (-0.529, -0.774, -1.205)],
['O', 3, (0.627, 1.062, 0.000)],
],
'ARG': [
['N', 0, (-0.524, 1.362, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.525, -0.000, -0.000)],
['CB', 0, (-0.524, -0.778, -1.209)],
['O', 3, (0.626, 1.062, 0.000)],
['CG', 4, (0.616, 1.390, -0.000)],
['CD', 5, (0.564, 1.414, 0.000)],
['NE', 6, (0.539, 1.357, -0.000)],
['NH1', 7, (0.206, 2.301, 0.000)],
['NH2', 7, (2.078, 0.978, -0.000)],
['CZ', 7, (0.758, 1.093, -0.000)],
],
'ASN': [
['N', 0, (-0.536, 1.357, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.526, -0.000, -0.000)],
['CB', 0, (-0.531, -0.787, -1.200)],
['O', 3, (0.625, 1.062, 0.000)],
['CG', 4, (0.584, 1.399, 0.000)],
['ND2', 5, (0.593, -1.188, 0.001)],
['OD1', 5, (0.633, 1.059, 0.000)],
],
'ASP': [
['N', 0, (-0.525, 1.362, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.527, 0.000, -0.000)],
['CB', 0, (-0.526, -0.778, -1.208)],
['O', 3, (0.626, 1.062, -0.000)],
['CG', 4, (0.593, 1.398, -0.000)],
['OD1', 5, (0.610, 1.091, 0.000)],
['OD2', 5, (0.592, -1.101, -0.003)],
],
'CYS': [
['N', 0, (-0.522, 1.362, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.524, 0.000, 0.000)],
['CB', 0, (-0.519, -0.773, -1.212)],
['O', 3, (0.625, 1.062, -0.000)],
['SG', 4, (0.728, 1.653, 0.000)],
],
'GLN': [
['N', 0, (-0.526, 1.361, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.526, 0.000, 0.000)],
['CB', 0, (-0.525, -0.779, -1.207)],
['O', 3, (0.626, 1.062, -0.000)],
['CG', 4, (0.615, 1.393, 0.000)],
['CD', 5, (0.587, 1.399, -0.000)],
['NE2', 6, (0.593, -1.189, -0.001)],
['OE1', 6, (0.634, 1.060, 0.000)],
],
'GLU': [
['N', 0, (-0.528, 1.361, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.526, -0.000, -0.000)],
['CB', 0, (-0.526, -0.781, -1.207)],
['O', 3, (0.626, 1.062, 0.000)],
['CG', 4, (0.615, 1.392, 0.000)],
['CD', 5, (0.600, 1.397, 0.000)],
['OE1', 6, (0.607, 1.095, -0.000)],
['OE2', 6, (0.589, -1.104, -0.001)],
],
'GLY': [
['N', 0, (-0.572, 1.337, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.517, -0.000, -0.000)],
['O', 3, (0.626, 1.062, -0.000)],
],
'HIS': [
['N', 0, (-0.527, 1.360, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.525, 0.000, 0.000)],
['CB', 0, (-0.525, -0.778, -1.208)],
['O', 3, (0.625, 1.063, 0.000)],
['CG', 4, (0.600, 1.370, -0.000)],
['CD2', 5, (0.889, -1.021, 0.003)],
['ND1', 5, (0.744, 1.160, -0.000)],
['CE1', 5, (2.030, 0.851, 0.002)],
['NE2', 5, (2.145, -0.466, 0.004)],
],
'ILE': [
['N', 0, (-0.493, 1.373, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.527, -0.000, -0.000)],
['CB', 0, (-0.536, -0.793, -1.213)],
['O', 3, (0.627, 1.062, -0.000)],
['CG1', 4, (0.534, 1.437, -0.000)],
['CG2', 4, (0.540, -0.785, -1.199)],
['CD1', 5, (0.619, 1.391, 0.000)],
],
'LEU': [
['N', 0, (-0.520, 1.363, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.525, -0.000, -0.000)],
['CB', 0, (-0.522, -0.773, -1.214)],
['O', 3, (0.625, 1.063, -0.000)],
['CG', 4, (0.678, 1.371, 0.000)],
['CD1', 5, (0.530, 1.430, -0.000)],
['CD2', 5, (0.535, -0.774, 1.200)],
],
'LYS': [
['N', 0, (-0.526, 1.362, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.526, 0.000, 0.000)],
['CB', 0, (-0.524, -0.778, -1.208)],
['O', 3, (0.626, 1.062, -0.000)],
['CG', 4, (0.619, 1.390, 0.000)],
['CD', 5, (0.559, 1.417, 0.000)],
['CE', 6, (0.560, 1.416, 0.000)],
['NZ', 7, (0.554, 1.387, 0.000)],
],
'MET': [
['N', 0, (-0.521, 1.364, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.525, 0.000, 0.000)],
['CB', 0, (-0.523, -0.776, -1.210)],
['O', 3, (0.625, 1.062, -0.000)],
['CG', 4, (0.613, 1.391, -0.000)],
['SD', 5, (0.703, 1.695, 0.000)],
['CE', 6, (0.320, 1.786, -0.000)],
],
'PHE': [
['N', 0, (-0.518, 1.363, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.524, 0.000, -0.000)],
['CB', 0, (-0.525, -0.776, -1.212)],
['O', 3, (0.626, 1.062, -0.000)],
['CG', 4, (0.607, 1.377, 0.000)],
['CD1', 5, (0.709, 1.195, -0.000)],
['CD2', 5, (0.706, -1.196, 0.000)],
['CE1', 5, (2.102, 1.198, -0.000)],
['CE2', 5, (2.098, -1.201, -0.000)],
['CZ', 5, (2.794, -0.003, -0.001)],
],
'PRO': [
['N', 0, (-0.566, 1.351, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.527, -0.000, 0.000)],
['CB', 0, (-0.546, -0.611, -1.293)],
['O', 3, (0.621, 1.066, 0.000)],
['CG', 4, (0.382, 1.445, 0.0)],
"ALA": [
["N", 0, (-0.525, 1.363, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, -0.000, -0.000)],
["CB", 0, (-0.529, -0.774, -1.205)],
["O", 3, (0.627, 1.062, 0.000)],
],
"ARG": [
["N", 0, (-0.524, 1.362, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, -0.000, -0.000)],
["CB", 0, (-0.524, -0.778, -1.209)],
["O", 3, (0.626, 1.062, 0.000)],
["CG", 4, (0.616, 1.390, -0.000)],
["CD", 5, (0.564, 1.414, 0.000)],
["NE", 6, (0.539, 1.357, -0.000)],
["NH1", 7, (0.206, 2.301, 0.000)],
["NH2", 7, (2.078, 0.978, -0.000)],
["CZ", 7, (0.758, 1.093, -0.000)],
],
"ASN": [
["N", 0, (-0.536, 1.357, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, -0.000, -0.000)],
["CB", 0, (-0.531, -0.787, -1.200)],
["O", 3, (0.625, 1.062, 0.000)],
["CG", 4, (0.584, 1.399, 0.000)],
["ND2", 5, (0.593, -1.188, 0.001)],
["OD1", 5, (0.633, 1.059, 0.000)],
],
"ASP": [
["N", 0, (-0.525, 1.362, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.527, 0.000, -0.000)],
["CB", 0, (-0.526, -0.778, -1.208)],
["O", 3, (0.626, 1.062, -0.000)],
["CG", 4, (0.593, 1.398, -0.000)],
["OD1", 5, (0.610, 1.091, 0.000)],
["OD2", 5, (0.592, -1.101, -0.003)],
],
"CYS": [
["N", 0, (-0.522, 1.362, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.524, 0.000, 0.000)],
["CB", 0, (-0.519, -0.773, -1.212)],
["O", 3, (0.625, 1.062, -0.000)],
["SG", 4, (0.728, 1.653, 0.000)],
],
"GLN": [
["N", 0, (-0.526, 1.361, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, 0.000, 0.000)],
["CB", 0, (-0.525, -0.779, -1.207)],
["O", 3, (0.626, 1.062, -0.000)],
["CG", 4, (0.615, 1.393, 0.000)],
["CD", 5, (0.587, 1.399, -0.000)],
["NE2", 6, (0.593, -1.189, -0.001)],
["OE1", 6, (0.634, 1.060, 0.000)],
],
"GLU": [
["N", 0, (-0.528, 1.361, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, -0.000, -0.000)],
["CB", 0, (-0.526, -0.781, -1.207)],
["O", 3, (0.626, 1.062, 0.000)],
["CG", 4, (0.615, 1.392, 0.000)],
["CD", 5, (0.600, 1.397, 0.000)],
["OE1", 6, (0.607, 1.095, -0.000)],
["OE2", 6, (0.589, -1.104, -0.001)],
],
"GLY": [
["N", 0, (-0.572, 1.337, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.517, -0.000, -0.000)],
["O", 3, (0.626, 1.062, -0.000)],
],
"HIS": [
["N", 0, (-0.527, 1.360, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, 0.000, 0.000)],
["CB", 0, (-0.525, -0.778, -1.208)],
["O", 3, (0.625, 1.063, 0.000)],
["CG", 4, (0.600, 1.370, -0.000)],
["CD2", 5, (0.889, -1.021, 0.003)],
["ND1", 5, (0.744, 1.160, -0.000)],
["CE1", 5, (2.030, 0.851, 0.002)],
["NE2", 5, (2.145, -0.466, 0.004)],
],
"ILE": [
["N", 0, (-0.493, 1.373, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.527, -0.000, -0.000)],
["CB", 0, (-0.536, -0.793, -1.213)],
["O", 3, (0.627, 1.062, -0.000)],
["CG1", 4, (0.534, 1.437, -0.000)],
["CG2", 4, (0.540, -0.785, -1.199)],
["CD1", 5, (0.619, 1.391, 0.000)],
],
"LEU": [
["N", 0, (-0.520, 1.363, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, -0.000, -0.000)],
["CB", 0, (-0.522, -0.773, -1.214)],
["O", 3, (0.625, 1.063, -0.000)],
["CG", 4, (0.678, 1.371, 0.000)],
["CD1", 5, (0.530, 1.430, -0.000)],
["CD2", 5, (0.535, -0.774, 1.200)],
],
"LYS": [
["N", 0, (-0.526, 1.362, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, 0.000, 0.000)],
["CB", 0, (-0.524, -0.778, -1.208)],
["O", 3, (0.626, 1.062, -0.000)],
["CG", 4, (0.619, 1.390, 0.000)],
["CD", 5, (0.559, 1.417, 0.000)],
["CE", 6, (0.560, 1.416, 0.000)],
["NZ", 7, (0.554, 1.387, 0.000)],
],
"MET": [
["N", 0, (-0.521, 1.364, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, 0.000, 0.000)],
["CB", 0, (-0.523, -0.776, -1.210)],
["O", 3, (0.625, 1.062, -0.000)],
["CG", 4, (0.613, 1.391, -0.000)],
["SD", 5, (0.703, 1.695, 0.000)],
["CE", 6, (0.320, 1.786, -0.000)],
],
"PHE": [
["N", 0, (-0.518, 1.363, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.524, 0.000, -0.000)],
["CB", 0, (-0.525, -0.776, -1.212)],
["O", 3, (0.626, 1.062, -0.000)],
["CG", 4, (0.607, 1.377, 0.000)],
["CD1", 5, (0.709, 1.195, -0.000)],
["CD2", 5, (0.706, -1.196, 0.000)],
["CE1", 5, (2.102, 1.198, -0.000)],
["CE2", 5, (2.098, -1.201, -0.000)],
["CZ", 5, (2.794, -0.003, -0.001)],
],
"PRO": [
["N", 0, (-0.566, 1.351, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.527, -0.000, 0.000)],
["CB", 0, (-0.546, -0.611, -1.293)],
["O", 3, (0.621, 1.066, 0.000)],
["CG", 4, (0.382, 1.445, 0.0)],
# ['CD', 5, (0.427, 1.440, 0.0)],
['CD', 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger
],
'SER': [
['N', 0, (-0.529, 1.360, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.525, -0.000, -0.000)],
['CB', 0, (-0.518, -0.777, -1.211)],
['O', 3, (0.626, 1.062, -0.000)],
['OG', 4, (0.503, 1.325, 0.000)],
],
'THR': [
['N', 0, (-0.517, 1.364, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.526, 0.000, -0.000)],
['CB', 0, (-0.516, -0.793, -1.215)],
['O', 3, (0.626, 1.062, 0.000)],
['CG2', 4, (0.550, -0.718, -1.228)],
['OG1', 4, (0.472, 1.353, 0.000)],
],
'TRP': [
['N', 0, (-0.521, 1.363, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.525, -0.000, 0.000)],
['CB', 0, (-0.523, -0.776, -1.212)],
['O', 3, (0.627, 1.062, 0.000)],
['CG', 4, (0.609, 1.370, -0.000)],
['CD1', 5, (0.824, 1.091, 0.000)],
['CD2', 5, (0.854, -1.148, -0.005)],
['CE2', 5, (2.186, -0.678, -0.007)],
['CE3', 5, (0.622, -2.530, -0.007)],
['NE1', 5, (2.140, 0.690, -0.004)],
['CH2', 5, (3.028, -2.890, -0.013)],
['CZ2', 5, (3.283, -1.543, -0.011)],
['CZ3', 5, (1.715, -3.389, -0.011)],
],
'TYR': [
['N', 0, (-0.522, 1.362, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.524, -0.000, -0.000)],
['CB', 0, (-0.522, -0.776, -1.213)],
['O', 3, (0.627, 1.062, -0.000)],
['CG', 4, (0.607, 1.382, -0.000)],
['CD1', 5, (0.716, 1.195, -0.000)],
['CD2', 5, (0.713, -1.194, -0.001)],
['CE1', 5, (2.107, 1.200, -0.002)],
['CE2', 5, (2.104, -1.201, -0.003)],
['OH', 5, (4.168, -0.002, -0.005)],
['CZ', 5, (2.791, -0.001, -0.003)],
],
'VAL': [
['N', 0, (-0.494, 1.373, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.527, -0.000, -0.000)],
['CB', 0, (-0.533, -0.795, -1.213)],
['O', 3, (0.627, 1.062, -0.000)],
['CG1', 4, (0.540, 1.429, -0.000)],
['CG2', 4, (0.533, -0.776, 1.203)],
["CD", 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger
],
"SER": [
["N", 0, (-0.529, 1.360, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, -0.000, -0.000)],
["CB", 0, (-0.518, -0.777, -1.211)],
["O", 3, (0.626, 1.062, -0.000)],
["OG", 4, (0.503, 1.325, 0.000)],
],
"THR": [
["N", 0, (-0.517, 1.364, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, 0.000, -0.000)],
["CB", 0, (-0.516, -0.793, -1.215)],
["O", 3, (0.626, 1.062, 0.000)],
["CG2", 4, (0.550, -0.718, -1.228)],
["OG1", 4, (0.472, 1.353, 0.000)],
],
"TRP": [
["N", 0, (-0.521, 1.363, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, -0.000, 0.000)],
["CB", 0, (-0.523, -0.776, -1.212)],
["O", 3, (0.627, 1.062, 0.000)],
["CG", 4, (0.609, 1.370, -0.000)],
["CD1", 5, (0.824, 1.091, 0.000)],
["CD2", 5, (0.854, -1.148, -0.005)],
["CE2", 5, (2.186, -0.678, -0.007)],
["CE3", 5, (0.622, -2.530, -0.007)],
["NE1", 5, (2.140, 0.690, -0.004)],
["CH2", 5, (3.028, -2.890, -0.013)],
["CZ2", 5, (3.283, -1.543, -0.011)],
["CZ3", 5, (1.715, -3.389, -0.011)],
],
"TYR": [
["N", 0, (-0.522, 1.362, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.524, -0.000, -0.000)],
["CB", 0, (-0.522, -0.776, -1.213)],
["O", 3, (0.627, 1.062, -0.000)],
["CG", 4, (0.607, 1.382, -0.000)],
["CD1", 5, (0.716, 1.195, -0.000)],
["CD2", 5, (0.713, -1.194, -0.001)],
["CE1", 5, (2.107, 1.200, -0.002)],
["CE2", 5, (2.104, -1.201, -0.003)],
["OH", 5, (4.168, -0.002, -0.005)],
["CZ", 5, (2.791, -0.001, -0.003)],
],
"VAL": [
["N", 0, (-0.494, 1.373, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.527, -0.000, -0.000)],
["CB", 0, (-0.533, -0.795, -1.213)],
["O", 3, (0.627, 1.062, -0.000)],
["CG1", 4, (0.540, 1.429, -0.000)],
["CG2", 4, (0.533, -0.776, 1.203)],
],
}
# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention.
residue_atoms = {
'ALA': ['C', 'CA', 'CB', 'N', 'O'],
'ARG': ['C', 'CA', 'CB', 'CG', 'CD', 'CZ', 'N', 'NE', 'O', 'NH1', 'NH2'],
'ASP': ['C', 'CA', 'CB', 'CG', 'N', 'O', 'OD1', 'OD2'],
'ASN': ['C', 'CA', 'CB', 'CG', 'N', 'ND2', 'O', 'OD1'],
'CYS': ['C', 'CA', 'CB', 'N', 'O', 'SG'],
'GLU': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O', 'OE1', 'OE2'],
'GLN': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'NE2', 'O', 'OE1'],
'GLY': ['C', 'CA', 'N', 'O'],
'HIS': ['C', 'CA', 'CB', 'CG', 'CD2', 'CE1', 'N', 'ND1', 'NE2', 'O'],
'ILE': ['C', 'CA', 'CB', 'CG1', 'CG2', 'CD1', 'N', 'O'],
'LEU': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'N', 'O'],
'LYS': ['C', 'CA', 'CB', 'CG', 'CD', 'CE', 'N', 'NZ', 'O'],
'MET': ['C', 'CA', 'CB', 'CG', 'CE', 'N', 'O', 'SD'],
'PHE': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O'],
'PRO': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O'],
'SER': ['C', 'CA', 'CB', 'N', 'O', 'OG'],
'THR': ['C', 'CA', 'CB', 'CG2', 'N', 'O', 'OG1'],
'TRP': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE2', 'CE3', 'CZ2', 'CZ3',
'CH2', 'N', 'NE1', 'O'],
'TYR': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O',
'OH'],
'VAL': ['C', 'CA', 'CB', 'CG1', 'CG2', 'N', 'O']
"ALA": ["C", "CA", "CB", "N", "O"],
"ARG": ["C", "CA", "CB", "CG", "CD", "CZ", "N", "NE", "O", "NH1", "NH2"],
"ASP": ["C", "CA", "CB", "CG", "N", "O", "OD1", "OD2"],
"ASN": ["C", "CA", "CB", "CG", "N", "ND2", "O", "OD1"],
"CYS": ["C", "CA", "CB", "N", "O", "SG"],
"GLU": ["C", "CA", "CB", "CG", "CD", "N", "O", "OE1", "OE2"],
"GLN": ["C", "CA", "CB", "CG", "CD", "N", "NE2", "O", "OE1"],
"GLY": ["C", "CA", "N", "O"],
"HIS": ["C", "CA", "CB", "CG", "CD2", "CE1", "N", "ND1", "NE2", "O"],
"ILE": ["C", "CA", "CB", "CG1", "CG2", "CD1", "N", "O"],
"LEU": ["C", "CA", "CB", "CG", "CD1", "CD2", "N", "O"],
"LYS": ["C", "CA", "CB", "CG", "CD", "CE", "N", "NZ", "O"],
"MET": ["C", "CA", "CB", "CG", "CE", "N", "O", "SD"],
"PHE": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O"],
"PRO": ["C", "CA", "CB", "CG", "CD", "N", "O"],
"SER": ["C", "CA", "CB", "N", "O", "OG"],
"THR": ["C", "CA", "CB", "CG2", "N", "O", "OG1"],
"TRP": [
"C",
"CA",
"CB",
"CG",
"CD1",
"CD2",
"CE2",
"CE3",
"CZ2",
"CZ3",
"CH2",
"N",
"NE1",
"O",
],
"TYR": [
"C",
"CA",
"CB",
"CG",
"CD1",
"CD2",
"CE1",
"CE2",
"CZ",
"N",
"O",
"OH",
],
"VAL": ["C", "CA", "CB", "CG1", "CG2", "N", "O"],
}
# Naming swaps for ambiguous atom names.
......@@ -368,115 +411,134 @@ residue_atoms = {
# the 'ambiguous' atoms and their neighbours)
# TODO: ^ interpret this
residue_atom_renaming_swaps = {
'ASP': {'OD1': 'OD2'},
'GLU': {'OE1': 'OE2'},
'PHE': {'CD1': 'CD2', 'CE1': 'CE2'},
'TYR': {'CD1': 'CD2', 'CE1': 'CE2'},
"ASP": {"OD1": "OD2"},
"GLU": {"OE1": "OE2"},
"PHE": {"CD1": "CD2", "CE1": "CE2"},
"TYR": {"CD1": "CD2", "CE1": "CE2"},
}
# Van der Waals radii [Angstroem] of the atoms (from Wikipedia)
van_der_waals_radius = {
'C': 1.7,
'N': 1.55,
'O': 1.52,
'S': 1.8,
"C": 1.7,
"N": 1.55,
"O": 1.52,
"S": 1.8,
}
Bond = collections.namedtuple(
'Bond', ['atom1_name', 'atom2_name', 'length', 'stddev'])
"Bond", ["atom1_name", "atom2_name", "length", "stddev"]
)
BondAngle = collections.namedtuple(
'BondAngle',
['atom1_name', 'atom2_name', 'atom3name', 'angle_rad', 'stddev'])
"BondAngle",
["atom1_name", "atom2_name", "atom3name", "angle_rad", "stddev"],
)
@functools.lru_cache(maxsize=None)
def load_stereo_chemical_props() -> Tuple[Mapping[str, List[Bond]],
Mapping[str, List[Bond]],
Mapping[str, List[BondAngle]]]:
"""Load stereo_chemical_props.txt into a nice structure.
Load literature values for bond lengths and bond angles and translate
bond angles into the length of the opposite edge of the triangle
("residue_virtual_bonds").
Returns:
residue_bonds: dict that maps resname --> list of Bond tuples
residue_virtual_bonds: dict that maps resname --> list of Bond tuples
residue_bond_angles: dict that maps resname --> list of BondAngle tuples
"""
# TODO: this file should be downloaded in a setup script
stereo_chemical_props_path = (
'openfold/resources/stereo_chemical_props.txt')
with open(stereo_chemical_props_path, 'rt') as f:
stereo_chemical_props = f.read()
lines_iter = iter(stereo_chemical_props.splitlines())
# Load bond lengths.
residue_bonds = {}
next(lines_iter) # Skip header line.
for line in lines_iter:
if line.strip() == '-':
break
bond, resname, length, stddev = line.split()
atom1, atom2 = bond.split('-')
if resname not in residue_bonds:
residue_bonds[resname] = []
residue_bonds[resname].append(
Bond(atom1, atom2, float(length), float(stddev)))
residue_bonds['UNK'] = []
# Load bond angles.
residue_bond_angles = {}
next(lines_iter) # Skip empty line.
next(lines_iter) # Skip header line.
for line in lines_iter:
if line.strip() == '-':
break
bond, resname, angle_degree, stddev_degree = line.split()
atom1, atom2, atom3 = bond.split('-')
if resname not in residue_bond_angles:
residue_bond_angles[resname] = []
residue_bond_angles[resname].append(
BondAngle(atom1, atom2, atom3,
float(angle_degree) / 180. * np.pi,
float(stddev_degree) / 180. * np.pi))
residue_bond_angles['UNK'] = []
def make_bond_key(atom1_name, atom2_name):
"""Unique key to lookup bonds."""
return '-'.join(sorted([atom1_name, atom2_name]))
# Translate bond angles into distances ("virtual bonds").
residue_virtual_bonds = {}
for resname, bond_angles in residue_bond_angles.items():
# Create a fast lookup dict for bond lengths.
bond_cache = {}
for b in residue_bonds[resname]:
bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b
residue_virtual_bonds[resname] = []
for ba in bond_angles:
bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)]
bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)]
# Compute distance between atom1 and atom3 using the law of cosines
# c^2 = a^2 + b^2 - 2ab*cos(gamma).
gamma = ba.angle_rad
length = np.sqrt(bond1.length**2 + bond2.length**2
- 2 * bond1.length * bond2.length * np.cos(gamma))
# Propagation of uncertainty assuming uncorrelated errors.
dl_outer = 0.5 / length
dl_dgamma = (2 * bond1.length * bond2.length * np.sin(gamma)) * dl_outer
dl_db1 = (2 * bond1.length - 2 * bond2.length * np.cos(gamma)) * dl_outer
dl_db2 = (2 * bond2.length - 2 * bond1.length * np.cos(gamma)) * dl_outer
stddev = np.sqrt((dl_dgamma * ba.stddev)**2 +
(dl_db1 * bond1.stddev)**2 +
(dl_db2 * bond2.stddev)**2)
residue_virtual_bonds[resname].append(
Bond(ba.atom1_name, ba.atom3name, length, stddev))
return (residue_bonds,
residue_virtual_bonds,
residue_bond_angles)
def load_stereo_chemical_props() -> Tuple[
Mapping[str, List[Bond]],
Mapping[str, List[Bond]],
Mapping[str, List[BondAngle]],
]:
"""Load stereo_chemical_props.txt into a nice structure.
Load literature values for bond lengths and bond angles and translate
bond angles into the length of the opposite edge of the triangle
("residue_virtual_bonds").
Returns:
residue_bonds: dict that maps resname --> list of Bond tuples
residue_virtual_bonds: dict that maps resname --> list of Bond tuples
residue_bond_angles: dict that maps resname --> list of BondAngle tuples
"""
# TODO: this file should be downloaded in a setup script
stereo_chemical_props_path = "openfold/resources/stereo_chemical_props.txt"
with open(stereo_chemical_props_path, "rt") as f:
stereo_chemical_props = f.read()
lines_iter = iter(stereo_chemical_props.splitlines())
# Load bond lengths.
residue_bonds = {}
next(lines_iter) # Skip header line.
for line in lines_iter:
if line.strip() == "-":
break
bond, resname, length, stddev = line.split()
atom1, atom2 = bond.split("-")
if resname not in residue_bonds:
residue_bonds[resname] = []
residue_bonds[resname].append(
Bond(atom1, atom2, float(length), float(stddev))
)
residue_bonds["UNK"] = []
# Load bond angles.
residue_bond_angles = {}
next(lines_iter) # Skip empty line.
next(lines_iter) # Skip header line.
for line in lines_iter:
if line.strip() == "-":
break
bond, resname, angle_degree, stddev_degree = line.split()
atom1, atom2, atom3 = bond.split("-")
if resname not in residue_bond_angles:
residue_bond_angles[resname] = []
residue_bond_angles[resname].append(
BondAngle(
atom1,
atom2,
atom3,
float(angle_degree) / 180.0 * np.pi,
float(stddev_degree) / 180.0 * np.pi,
)
)
residue_bond_angles["UNK"] = []
def make_bond_key(atom1_name, atom2_name):
"""Unique key to lookup bonds."""
return "-".join(sorted([atom1_name, atom2_name]))
# Translate bond angles into distances ("virtual bonds").
residue_virtual_bonds = {}
for resname, bond_angles in residue_bond_angles.items():
# Create a fast lookup dict for bond lengths.
bond_cache = {}
for b in residue_bonds[resname]:
bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b
residue_virtual_bonds[resname] = []
for ba in bond_angles:
bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)]
bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)]
# Compute distance between atom1 and atom3 using the law of cosines
# c^2 = a^2 + b^2 - 2ab*cos(gamma).
gamma = ba.angle_rad
length = np.sqrt(
bond1.length ** 2
+ bond2.length ** 2
- 2 * bond1.length * bond2.length * np.cos(gamma)
)
# Propagation of uncertainty assuming uncorrelated errors.
dl_outer = 0.5 / length
dl_dgamma = (
2 * bond1.length * bond2.length * np.sin(gamma)
) * dl_outer
dl_db1 = (
2 * bond1.length - 2 * bond2.length * np.cos(gamma)
) * dl_outer
dl_db2 = (
2 * bond2.length - 2 * bond1.length * np.cos(gamma)
) * dl_outer
stddev = np.sqrt(
(dl_dgamma * ba.stddev) ** 2
+ (dl_db1 * bond1.stddev) ** 2
+ (dl_db2 * bond2.stddev) ** 2
)
residue_virtual_bonds[resname].append(
Bond(ba.atom1_name, ba.atom3name, length, stddev)
)
return (residue_bonds, residue_virtual_bonds, residue_bond_angles)
# Between-residue bond lengths for general bonds (first element) and for Proline
......@@ -491,10 +553,43 @@ between_res_cos_angles_ca_c_n = [-0.4473, 0.0311] # degrees: 116.568 +- 1.995
# This mapping is used when we need to store atom data in a format that requires
# fixed atom data size for every residue (e.g. a numpy array).
atom_types = [
'N', 'CA', 'C', 'CB', 'O', 'CG', 'CG1', 'CG2', 'OG', 'OG1', 'SG', 'CD',
'CD1', 'CD2', 'ND1', 'ND2', 'OD1', 'OD2', 'SD', 'CE', 'CE1', 'CE2', 'CE3',
'NE', 'NE1', 'NE2', 'OE1', 'OE2', 'CH2', 'NH1', 'NH2', 'OH', 'CZ', 'CZ2',
'CZ3', 'NZ', 'OXT'
"N",
"CA",
"C",
"CB",
"O",
"CG",
"CG1",
"CG2",
"OG",
"OG1",
"SG",
"CD",
"CD1",
"CD2",
"ND1",
"ND2",
"OD1",
"OD2",
"SD",
"CE",
"CE1",
"CE2",
"CE3",
"NE",
"NE1",
"NE2",
"OE1",
"OE2",
"CH2",
"NH1",
"NH2",
"OH",
"CZ",
"CZ2",
"CZ3",
"NZ",
"OXT",
]
atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
atom_type_num = len(atom_types) # := 37.
......@@ -503,28 +598,252 @@ atom_type_num = len(atom_types) # := 37.
# pylint: disable=line-too-long
# pylint: disable=bad-whitespace
restype_name_to_atom14_names = {
'ALA': ['N', 'CA', 'C', 'O', 'CB', '', '', '', '', '', '', '', '', ''],
'ARG': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2', '', '', ''],
'ASN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2', '', '', '', '', '', ''],
'ASP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2', '', '', '', '', '', ''],
'CYS': ['N', 'CA', 'C', 'O', 'CB', 'SG', '', '', '', '', '', '', '', ''],
'GLN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2', '', '', '', '', ''],
'GLU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2', '', '', '', '', ''],
'GLY': ['N', 'CA', 'C', 'O', '', '', '', '', '', '', '', '', '', ''],
'HIS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'ND1', 'CD2', 'CE1', 'NE2', '', '', '', ''],
'ILE': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '', '', '', '', '', ''],
'LEU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', '', '', '', '', '', ''],
'LYS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ', '', '', '', '', ''],
'MET': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE', '', '', '', '', '', ''],
'PHE': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', '', '', ''],
'PRO': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', '', '', '', '', '', '', ''],
'SER': ['N', 'CA', 'C', 'O', 'CB', 'OG', '', '', '', '', '', '', '', ''],
'THR': ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '', '', '', '', '', '', ''],
'TRP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2'],
'TYR': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OH', '', ''],
'VAL': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '', '', '', '', '', '', ''],
'UNK': ['', '', '', '', '', '', '', '', '', '', '', '', '', ''],
"ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", ""],
"ARG": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD",
"NE",
"CZ",
"NH1",
"NH2",
"",
"",
"",
],
"ASN": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"OD1",
"ND2",
"",
"",
"",
"",
"",
"",
],
"ASP": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"OD1",
"OD2",
"",
"",
"",
"",
"",
"",
],
"CYS": ["N", "CA", "C", "O", "CB", "SG", "", "", "", "", "", "", "", ""],
"GLN": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD",
"OE1",
"NE2",
"",
"",
"",
"",
"",
],
"GLU": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD",
"OE1",
"OE2",
"",
"",
"",
"",
"",
],
"GLY": ["N", "CA", "C", "O", "", "", "", "", "", "", "", "", "", ""],
"HIS": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"ND1",
"CD2",
"CE1",
"NE2",
"",
"",
"",
"",
],
"ILE": [
"N",
"CA",
"C",
"O",
"CB",
"CG1",
"CG2",
"CD1",
"",
"",
"",
"",
"",
"",
],
"LEU": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD1",
"CD2",
"",
"",
"",
"",
"",
"",
],
"LYS": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD",
"CE",
"NZ",
"",
"",
"",
"",
"",
],
"MET": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"SD",
"CE",
"",
"",
"",
"",
"",
"",
],
"PHE": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD1",
"CD2",
"CE1",
"CE2",
"CZ",
"",
"",
"",
],
"PRO": ["N", "CA", "C", "O", "CB", "CG", "CD", "", "", "", "", "", "", ""],
"SER": ["N", "CA", "C", "O", "CB", "OG", "", "", "", "", "", "", "", ""],
"THR": [
"N",
"CA",
"C",
"O",
"CB",
"OG1",
"CG2",
"",
"",
"",
"",
"",
"",
"",
],
"TRP": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD1",
"CD2",
"NE1",
"CE2",
"CE3",
"CZ2",
"CZ3",
"CH2",
],
"TYR": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD1",
"CD2",
"CE1",
"CE2",
"CZ",
"OH",
"",
"",
],
"VAL": [
"N",
"CA",
"C",
"O",
"CB",
"CG1",
"CG2",
"",
"",
"",
"",
"",
"",
"",
],
"UNK": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""],
}
# pylint: enable=line-too-long
# pylint: enable=bad-whitespace
......@@ -533,81 +852,102 @@ restype_name_to_atom14_names = {
# This is the standard residue order when coding AA type as a number.
# Reproduce it by taking 3-letter AA codes and sorting them alphabetically.
restypes = [
'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P',
'S', 'T', 'W', 'Y', 'V'
"A",
"R",
"N",
"D",
"C",
"Q",
"E",
"G",
"H",
"I",
"L",
"K",
"M",
"F",
"P",
"S",
"T",
"W",
"Y",
"V",
]
restype_order = {restype: i for i, restype in enumerate(restypes)}
restype_num = len(restypes) # := 20.
unk_restype_index = restype_num # Catch-all index for unknown restypes.
restypes_with_x = restypes + ['X']
restypes_with_x = restypes + ["X"]
restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)}
def sequence_to_onehot(
sequence: str,
mapping: Mapping[str, int],
map_unknown_to_x: bool = False) -> np.ndarray:
"""Maps the given sequence into a one-hot encoded matrix.
Args:
sequence: An amino acid sequence.
mapping: A dictionary mapping amino acids to integers.
map_unknown_to_x: If True, any amino acid that is not in the mapping will be
mapped to the unknown amino acid 'X'. If the mapping doesn't contain
amino acid 'X', an error will be thrown. If False, any amino acid not in
the mapping will throw an error.
Returns:
A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of
the sequence.
Raises:
ValueError: If the mapping doesn't contain values from 0 to
num_unique_aas - 1 without any gaps.
"""
num_entries = max(mapping.values()) + 1
if sorted(set(mapping.values())) != list(range(num_entries)):
raise ValueError('The mapping must have values from 0 to num_unique_aas-1 '
'without any gaps. Got: %s' % sorted(mapping.values()))
one_hot_arr = np.zeros((len(sequence), num_entries), dtype=np.int32)
for aa_index, aa_type in enumerate(sequence):
if map_unknown_to_x:
if aa_type.isalpha() and aa_type.isupper():
aa_id = mapping.get(aa_type, mapping['X'])
else:
raise ValueError(f'Invalid character in the sequence: {aa_type}')
else:
aa_id = mapping[aa_type]
one_hot_arr[aa_index, aa_id] = 1
return one_hot_arr
sequence: str, mapping: Mapping[str, int], map_unknown_to_x: bool = False
) -> np.ndarray:
"""Maps the given sequence into a one-hot encoded matrix.
Args:
sequence: An amino acid sequence.
mapping: A dictionary mapping amino acids to integers.
map_unknown_to_x: If True, any amino acid that is not in the mapping will be
mapped to the unknown amino acid 'X'. If the mapping doesn't contain
amino acid 'X', an error will be thrown. If False, any amino acid not in
the mapping will throw an error.
Returns:
A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of
the sequence.
Raises:
ValueError: If the mapping doesn't contain values from 0 to
num_unique_aas - 1 without any gaps.
"""
num_entries = max(mapping.values()) + 1
if sorted(set(mapping.values())) != list(range(num_entries)):
raise ValueError(
"The mapping must have values from 0 to num_unique_aas-1 "
"without any gaps. Got: %s" % sorted(mapping.values())
)
one_hot_arr = np.zeros((len(sequence), num_entries), dtype=np.int32)
for aa_index, aa_type in enumerate(sequence):
if map_unknown_to_x:
if aa_type.isalpha() and aa_type.isupper():
aa_id = mapping.get(aa_type, mapping["X"])
else:
raise ValueError(
f"Invalid character in the sequence: {aa_type}"
)
else:
aa_id = mapping[aa_type]
one_hot_arr[aa_index, aa_id] = 1
return one_hot_arr
restype_1to3 = {
'A': 'ALA',
'R': 'ARG',
'N': 'ASN',
'D': 'ASP',
'C': 'CYS',
'Q': 'GLN',
'E': 'GLU',
'G': 'GLY',
'H': 'HIS',
'I': 'ILE',
'L': 'LEU',
'K': 'LYS',
'M': 'MET',
'F': 'PHE',
'P': 'PRO',
'S': 'SER',
'T': 'THR',
'W': 'TRP',
'Y': 'TYR',
'V': 'VAL',
"A": "ALA",
"R": "ARG",
"N": "ASN",
"D": "ASP",
"C": "CYS",
"Q": "GLN",
"E": "GLU",
"G": "GLY",
"H": "HIS",
"I": "ILE",
"L": "LEU",
"K": "LYS",
"M": "MET",
"F": "PHE",
"P": "PRO",
"S": "SER",
"T": "THR",
"W": "TRP",
"Y": "TYR",
"V": "VAL",
}
......@@ -618,7 +958,7 @@ restype_1to3 = {
restype_3to1 = {v: k for k, v in restype_1to3.items()}
# Define a restype name for all unknown residues.
unk_restype = 'UNK'
unk_restype = "UNK"
resnames = [restype_1to3[r] for r in restypes] + [unk_restype]
resname_to_idx = {resname: i for i, resname in enumerate(resnames)}
......@@ -632,78 +972,79 @@ resname_to_idx = {resname: i for i, resname in enumerate(resnames)}
# codes is put at the end (20 and 21) so that they can easily be ignored if
# desired.
HHBLITS_AA_TO_ID = {
'A': 0,
'B': 2,
'C': 1,
'D': 2,
'E': 3,
'F': 4,
'G': 5,
'H': 6,
'I': 7,
'J': 20,
'K': 8,
'L': 9,
'M': 10,
'N': 11,
'O': 20,
'P': 12,
'Q': 13,
'R': 14,
'S': 15,
'T': 16,
'U': 1,
'V': 17,
'W': 18,
'X': 20,
'Y': 19,
'Z': 3,
'-': 21,
"A": 0,
"B": 2,
"C": 1,
"D": 2,
"E": 3,
"F": 4,
"G": 5,
"H": 6,
"I": 7,
"J": 20,
"K": 8,
"L": 9,
"M": 10,
"N": 11,
"O": 20,
"P": 12,
"Q": 13,
"R": 14,
"S": 15,
"T": 16,
"U": 1,
"V": 17,
"W": 18,
"X": 20,
"Y": 19,
"Z": 3,
"-": 21,
}
# Partial inversion of HHBLITS_AA_TO_ID.
ID_TO_HHBLITS_AA = {
0: 'A',
1: 'C', # Also U.
2: 'D', # Also B.
3: 'E', # Also Z.
4: 'F',
5: 'G',
6: 'H',
7: 'I',
8: 'K',
9: 'L',
10: 'M',
11: 'N',
12: 'P',
13: 'Q',
14: 'R',
15: 'S',
16: 'T',
17: 'V',
18: 'W',
19: 'Y',
20: 'X', # Includes J and O.
21: '-',
0: "A",
1: "C", # Also U.
2: "D", # Also B.
3: "E", # Also Z.
4: "F",
5: "G",
6: "H",
7: "I",
8: "K",
9: "L",
10: "M",
11: "N",
12: "P",
13: "Q",
14: "R",
15: "S",
16: "T",
17: "V",
18: "W",
19: "Y",
20: "X", # Includes J and O.
21: "-",
}
restypes_with_x_and_gap = restypes + ['X', '-']
restypes_with_x_and_gap = restypes + ["X", "-"]
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple(
restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i])
for i in range(len(restypes_with_x_and_gap)))
for i in range(len(restypes_with_x_and_gap))
)
def _make_standard_atom_mask() -> np.ndarray:
"""Returns [num_res_types, num_atom_types] mask array."""
# +1 to account for unknown (all 0s).
mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32)
for restype, restype_letter in enumerate(restypes):
restype_name = restype_1to3[restype_letter]
atom_names = residue_atoms[restype_name]
for atom_name in atom_names:
atom_type = atom_order[atom_name]
mask[restype, atom_type] = 1
return mask
"""Returns [num_res_types, num_atom_types] mask array."""
# +1 to account for unknown (all 0s).
mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32)
for restype, restype_letter in enumerate(restypes):
restype_name = restype_1to3[restype_letter]
atom_names = residue_atoms[restype_name]
for atom_name in atom_names:
atom_type = atom_order[atom_name]
mask[restype, atom_type] = 1
return mask
STANDARD_ATOM_MASK = _make_standard_atom_mask()
......@@ -712,25 +1053,26 @@ STANDARD_ATOM_MASK = _make_standard_atom_mask()
# A one hot representation for the first and second atoms defining the axis
# of rotation for each chi-angle in each residue.
def chi_angle_atom(atom_index: int) -> np.ndarray:
"""Define chi-angle rigid groups via one-hot representations."""
chi_angles_index = {}
one_hots = []
"""Define chi-angle rigid groups via one-hot representations."""
chi_angles_index = {}
one_hots = []
for k, v in chi_angles_atoms.items():
indices = [atom_types.index(s[atom_index]) for s in v]
indices.extend([-1]*(4-len(indices)))
chi_angles_index[k] = indices
for k, v in chi_angles_atoms.items():
indices = [atom_types.index(s[atom_index]) for s in v]
indices.extend([-1] * (4 - len(indices)))
chi_angles_index[k] = indices
for r in restypes:
res3 = restype_1to3[r]
one_hot = np.eye(atom_type_num)[chi_angles_index[res3]]
one_hots.append(one_hot)
for r in restypes:
res3 = restype_1to3[r]
one_hot = np.eye(atom_type_num)[chi_angles_index[res3]]
one_hots.append(one_hot)
one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`.
one_hot = np.stack(one_hots, axis=0)
one_hot = np.transpose(one_hot, [0, 2, 1])
one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`.
one_hot = np.stack(one_hots, axis=0)
one_hot = np.transpose(one_hot, [0, 2, 1])
return one_hot
return one_hot
chi_atom_1_one_hot = chi_angle_atom(1)
chi_atom_2_one_hot = chi_angle_atom(2)
......@@ -738,35 +1080,41 @@ chi_atom_2_one_hot = chi_angle_atom(2)
# An array like chi_angles_atoms but using indices rather than names.
chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes]
chi_angles_atom_indices = tree.map_structure(
lambda atom_name: atom_order[atom_name], chi_angles_atom_indices)
chi_angles_atom_indices = np.array([
chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms)))
for chi_atoms in chi_angles_atom_indices])
lambda atom_name: atom_order[atom_name], chi_angles_atom_indices
)
chi_angles_atom_indices = np.array(
[
chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms)))
for chi_atoms in chi_angles_atom_indices
]
)
# Mapping from (res_name, atom_name) pairs to the atom's chi group index
# and atom index within that group.
chi_groups_for_atom = collections.defaultdict(list)
for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items():
for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res):
for atom_i, atom in enumerate(chi_group):
chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i))
for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res):
for atom_i, atom in enumerate(chi_group):
chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i))
chi_groups_for_atom = dict(chi_groups_for_atom)
def _make_rigid_transformation_4x4(ex, ey, translation):
"""Create a rigid 4x4 transformation matrix from two axes and transl."""
# Normalize ex.
ex_normalized = ex / np.linalg.norm(ex)
"""Create a rigid 4x4 transformation matrix from two axes and transl."""
# Normalize ex.
ex_normalized = ex / np.linalg.norm(ex)
# make ey perpendicular to ex
ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized
ey_normalized /= np.linalg.norm(ey_normalized)
# make ey perpendicular to ex
ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized
ey_normalized /= np.linalg.norm(ey_normalized)
# compute ez as cross product
eznorm = np.cross(ex_normalized, ey_normalized)
m = np.stack([ex_normalized, ey_normalized, eznorm, translation]).transpose()
m = np.concatenate([m, [[0., 0., 0., 1.]]], axis=0)
return m
# compute ez as cross product
eznorm = np.cross(ex_normalized, ey_normalized)
m = np.stack(
[ex_normalized, ey_normalized, eznorm, translation]
).transpose()
m = np.concatenate([m, [[0.0, 0.0, 0.0, 1.0]]], axis=0)
return m
# create an array with (restype, atomtype) --> rigid_group_idx
......@@ -783,138 +1131,173 @@ restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
def _make_rigid_group_constants():
"""Fill the arrays above."""
for restype, restype_letter in enumerate(restypes):
resname = restype_1to3[restype_letter]
for atomname, group_idx, atom_position in rigid_group_atom_positions[
resname]:
atomtype = atom_order[atomname]
restype_atom37_to_rigid_group[restype, atomtype] = group_idx
restype_atom37_mask[restype, atomtype] = 1
restype_atom37_rigid_group_positions[restype, atomtype, :] = atom_position
atom14idx = restype_name_to_atom14_names[resname].index(atomname)
restype_atom14_to_rigid_group[restype, atom14idx] = group_idx
restype_atom14_mask[restype, atom14idx] = 1
restype_atom14_rigid_group_positions[restype,
atom14idx, :] = atom_position
for restype, restype_letter in enumerate(restypes):
resname = restype_1to3[restype_letter]
atom_positions = {name: np.array(pos) for name, _, pos
in rigid_group_atom_positions[resname]}
# backbone to backbone is the identity transform
restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4)
# pre-omega-frame to backbone (currently dummy identity matrix)
restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4)
# phi-frame to backbone
mat = _make_rigid_transformation_4x4(
ex=atom_positions['N'] - atom_positions['CA'],
ey=np.array([1., 0., 0.]),
translation=atom_positions['N'])
restype_rigid_group_default_frame[restype, 2, :, :] = mat
# psi-frame to backbone
mat = _make_rigid_transformation_4x4(
ex=atom_positions['C'] - atom_positions['CA'],
ey=atom_positions['CA'] - atom_positions['N'],
translation=atom_positions['C'])
restype_rigid_group_default_frame[restype, 3, :, :] = mat
# chi1-frame to backbone
if chi_angles_mask[restype][0]:
base_atom_names = chi_angles_atoms[resname][0]
base_atom_positions = [atom_positions[name] for name in base_atom_names]
mat = _make_rigid_transformation_4x4(
ex=base_atom_positions[2] - base_atom_positions[1],
ey=base_atom_positions[0] - base_atom_positions[1],
translation=base_atom_positions[2])
restype_rigid_group_default_frame[restype, 4, :, :] = mat
# chi2-frame to chi1-frame
# chi3-frame to chi2-frame
# chi4-frame to chi3-frame
# luckily all rotation axes for the next frame start at (0,0,0) of the
# previous frame
for chi_idx in range(1, 4):
if chi_angles_mask[restype][chi_idx]:
axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2]
axis_end_atom_position = atom_positions[axis_end_atom_name]
"""Fill the arrays above."""
for restype, restype_letter in enumerate(restypes):
resname = restype_1to3[restype_letter]
for atomname, group_idx, atom_position in rigid_group_atom_positions[
resname
]:
atomtype = atom_order[atomname]
restype_atom37_to_rigid_group[restype, atomtype] = group_idx
restype_atom37_mask[restype, atomtype] = 1
restype_atom37_rigid_group_positions[
restype, atomtype, :
] = atom_position
atom14idx = restype_name_to_atom14_names[resname].index(atomname)
restype_atom14_to_rigid_group[restype, atom14idx] = group_idx
restype_atom14_mask[restype, atom14idx] = 1
restype_atom14_rigid_group_positions[
restype, atom14idx, :
] = atom_position
for restype, restype_letter in enumerate(restypes):
resname = restype_1to3[restype_letter]
atom_positions = {
name: np.array(pos)
for name, _, pos in rigid_group_atom_positions[resname]
}
# backbone to backbone is the identity transform
restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4)
# pre-omega-frame to backbone (currently dummy identity matrix)
restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4)
# phi-frame to backbone
mat = _make_rigid_transformation_4x4(
ex=axis_end_atom_position,
ey=np.array([-1., 0., 0.]),
translation=axis_end_atom_position)
restype_rigid_group_default_frame[restype, 4 + chi_idx, :, :] = mat
ex=atom_positions["N"] - atom_positions["CA"],
ey=np.array([1.0, 0.0, 0.0]),
translation=atom_positions["N"],
)
restype_rigid_group_default_frame[restype, 2, :, :] = mat
# psi-frame to backbone
mat = _make_rigid_transformation_4x4(
ex=atom_positions["C"] - atom_positions["CA"],
ey=atom_positions["CA"] - atom_positions["N"],
translation=atom_positions["C"],
)
restype_rigid_group_default_frame[restype, 3, :, :] = mat
# chi1-frame to backbone
if chi_angles_mask[restype][0]:
base_atom_names = chi_angles_atoms[resname][0]
base_atom_positions = [
atom_positions[name] for name in base_atom_names
]
mat = _make_rigid_transformation_4x4(
ex=base_atom_positions[2] - base_atom_positions[1],
ey=base_atom_positions[0] - base_atom_positions[1],
translation=base_atom_positions[2],
)
restype_rigid_group_default_frame[restype, 4, :, :] = mat
# chi2-frame to chi1-frame
# chi3-frame to chi2-frame
# chi4-frame to chi3-frame
# luckily all rotation axes for the next frame start at (0,0,0) of the
# previous frame
for chi_idx in range(1, 4):
if chi_angles_mask[restype][chi_idx]:
axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2]
axis_end_atom_position = atom_positions[axis_end_atom_name]
mat = _make_rigid_transformation_4x4(
ex=axis_end_atom_position,
ey=np.array([-1.0, 0.0, 0.0]),
translation=axis_end_atom_position,
)
restype_rigid_group_default_frame[
restype, 4 + chi_idx, :, :
] = mat
_make_rigid_group_constants()
def make_atom14_dists_bounds(overlap_tolerance=1.5,
bond_length_tolerance_factor=15):
"""compute upper and lower bounds for bonds to assess violations."""
restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32)
restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32)
restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32)
residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props()
for restype, restype_letter in enumerate(restypes):
resname = restype_1to3[restype_letter]
atom_list = restype_name_to_atom14_names[resname]
# create lower and upper bounds for clashes
for atom1_idx, atom1_name in enumerate(atom_list):
if not atom1_name:
continue
atom1_radius = van_der_waals_radius[atom1_name[0]]
for atom2_idx, atom2_name in enumerate(atom_list):
if (not atom2_name) or atom1_idx == atom2_idx:
continue
atom2_radius = van_der_waals_radius[atom2_name[0]]
lower = atom1_radius + atom2_radius - overlap_tolerance
upper = 1e10
restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower
restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower
restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper
restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper
# overwrite lower and upper bounds for bonds and angles
for b in residue_bonds[resname] + residue_virtual_bonds[resname]:
atom1_idx = atom_list.index(b.atom1_name)
atom2_idx = atom_list.index(b.atom2_name)
lower = b.length - bond_length_tolerance_factor * b.stddev
upper = b.length + bond_length_tolerance_factor * b.stddev
restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower
restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower
restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper
restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper
restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev
restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev
return {'lower_bound': restype_atom14_bond_lower_bound, # shape (21,14,14)
'upper_bound': restype_atom14_bond_upper_bound, # shape (21,14,14)
'stddev': restype_atom14_bond_stddev, # shape (21,14,14)
}
def make_atom14_dists_bounds(
overlap_tolerance=1.5, bond_length_tolerance_factor=15
):
"""compute upper and lower bounds for bonds to assess violations."""
restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32)
restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32)
restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32)
residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props()
for restype, restype_letter in enumerate(restypes):
resname = restype_1to3[restype_letter]
atom_list = restype_name_to_atom14_names[resname]
# create lower and upper bounds for clashes
for atom1_idx, atom1_name in enumerate(atom_list):
if not atom1_name:
continue
atom1_radius = van_der_waals_radius[atom1_name[0]]
for atom2_idx, atom2_name in enumerate(atom_list):
if (not atom2_name) or atom1_idx == atom2_idx:
continue
atom2_radius = van_der_waals_radius[atom2_name[0]]
lower = atom1_radius + atom2_radius - overlap_tolerance
upper = 1e10
restype_atom14_bond_lower_bound[
restype, atom1_idx, atom2_idx
] = lower
restype_atom14_bond_lower_bound[
restype, atom2_idx, atom1_idx
] = lower
restype_atom14_bond_upper_bound[
restype, atom1_idx, atom2_idx
] = upper
restype_atom14_bond_upper_bound[
restype, atom2_idx, atom1_idx
] = upper
# overwrite lower and upper bounds for bonds and angles
for b in residue_bonds[resname] + residue_virtual_bonds[resname]:
atom1_idx = atom_list.index(b.atom1_name)
atom2_idx = atom_list.index(b.atom2_name)
lower = b.length - bond_length_tolerance_factor * b.stddev
upper = b.length + bond_length_tolerance_factor * b.stddev
restype_atom14_bond_lower_bound[
restype, atom1_idx, atom2_idx
] = lower
restype_atom14_bond_lower_bound[
restype, atom2_idx, atom1_idx
] = lower
restype_atom14_bond_upper_bound[
restype, atom1_idx, atom2_idx
] = upper
restype_atom14_bond_upper_bound[
restype, atom2_idx, atom1_idx
] = upper
restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev
restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev
return {
"lower_bound": restype_atom14_bond_lower_bound, # shape (21,14,14)
"upper_bound": restype_atom14_bond_upper_bound, # shape (21,14,14)
"stddev": restype_atom14_bond_stddev, # shape (21,14,14)
}
restype_atom14_ambiguous_atoms = np.zeros((21, 14), dtype=np.float32)
restype_atom14_ambiguous_atoms_swap_idx = (
np.tile(np.arange(14, dtype=np.int), (21, 1))
restype_atom14_ambiguous_atoms_swap_idx = np.tile(
np.arange(14, dtype=np.int), (21, 1)
)
def _make_atom14_ambiguity_feats():
for res, pairs in residue_atom_renaming_swaps.items():
res_idx = restype_order[restype_3to1[res]]
for atom1, atom2 in pairs.items():
atom1_idx = restype_name_to_atom14_names[res].index(atom1)
atom1_idx = restype_name_to_atom14_names[res].index(atom1)
atom2_idx = restype_name_to_atom14_names[res].index(atom2)
restype_atom14_ambiguous_atoms[res_idx, atom1_idx] = 1
restype_atom14_ambiguous_atoms[res_idx, atom2_idx] = 1
restype_atom14_ambiguous_atoms_swap_idx[res_idx, atom1_idx] = (
atom2_idx
)
restype_atom14_ambiguous_atoms_swap_idx[res_idx, atom2_idx] = (
atom1_idx
)
restype_atom14_ambiguous_atoms_swap_idx[
res_idx, atom1_idx
] = atom2_idx
restype_atom14_ambiguous_atoms_swap_idx[
res_idx, atom2_idx
] = atom1_idx
_make_atom14_ambiguity_feats()
......@@ -3,13 +3,14 @@ import glob
import importlib as importlib
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [os.path.basename(f)[:-3] for f in _files if os.path.isfile(f) and not f.endswith("__init__.py")]
_modules = [(m, importlib.import_module('.' + m, __name__)) for m in __all__]
__all__ = [
os.path.basename(f)[:-3]
for f in _files
if os.path.isfile(f) and not f.endswith("__init__.py")
]
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
globals()[_m[0]] = _m[1]
# Avoid needlessly cluttering the global namespace
del _files, _m, _modules
......@@ -18,21 +18,48 @@ import torch
def rot_matmul(a, b):
row_1 = torch.stack([
a[...,0,0]*b[...,0,0] + a[...,0,1]*b[...,1,0] + a[...,0,2]*b[...,2,0],
a[...,0,0]*b[...,0,1] + a[...,0,1]*b[...,1,1] + a[...,0,2]*b[...,2,1],
a[...,0,0]*b[...,0,2] + a[...,0,1]*b[...,1,2] + a[...,0,2]*b[...,2,2],
], dim=-1)
row_2 = torch.stack([
a[...,1,0]*b[...,0,0] + a[...,1,1]*b[...,1,0] + a[...,1,2]*b[...,2,0],
a[...,1,0]*b[...,0,1] + a[...,1,1]*b[...,1,1] + a[...,1,2]*b[...,2,1],
a[...,1,0]*b[...,0,2] + a[...,1,1]*b[...,1,2] + a[...,1,2]*b[...,2,2],
], dim=-1)
row_3 = torch.stack([
a[...,2,0]*b[...,0,0] + a[...,2,1]*b[...,1,0] + a[...,2,2]*b[...,2,0],
a[...,2,0]*b[...,0,1] + a[...,2,1]*b[...,1,1] + a[...,2,2]*b[...,2,1],
a[...,2,0]*b[...,0,2] + a[...,2,1]*b[...,1,2] + a[...,2,2]*b[...,2,2],
], dim=-1)
row_1 = torch.stack(
[
a[..., 0, 0] * b[..., 0, 0]
+ a[..., 0, 1] * b[..., 1, 0]
+ a[..., 0, 2] * b[..., 2, 0],
a[..., 0, 0] * b[..., 0, 1]
+ a[..., 0, 1] * b[..., 1, 1]
+ a[..., 0, 2] * b[..., 2, 1],
a[..., 0, 0] * b[..., 0, 2]
+ a[..., 0, 1] * b[..., 1, 2]
+ a[..., 0, 2] * b[..., 2, 2],
],
dim=-1,
)
row_2 = torch.stack(
[
a[..., 1, 0] * b[..., 0, 0]
+ a[..., 1, 1] * b[..., 1, 0]
+ a[..., 1, 2] * b[..., 2, 0],
a[..., 1, 0] * b[..., 0, 1]
+ a[..., 1, 1] * b[..., 1, 1]
+ a[..., 1, 2] * b[..., 2, 1],
a[..., 1, 0] * b[..., 0, 2]
+ a[..., 1, 1] * b[..., 1, 2]
+ a[..., 1, 2] * b[..., 2, 2],
],
dim=-1,
)
row_3 = torch.stack(
[
a[..., 2, 0] * b[..., 0, 0]
+ a[..., 2, 1] * b[..., 1, 0]
+ a[..., 2, 2] * b[..., 2, 0],
a[..., 2, 0] * b[..., 0, 1]
+ a[..., 2, 1] * b[..., 1, 1]
+ a[..., 2, 2] * b[..., 2, 1],
a[..., 2, 0] * b[..., 0, 2]
+ a[..., 2, 1] * b[..., 1, 2]
+ a[..., 2, 2] * b[..., 2, 2],
],
dim=-1,
)
return torch.stack([row_1, row_2, row_3], dim=-2)
......@@ -41,52 +68,56 @@ def rot_vec_mul(r, t):
x = t[..., 0]
y = t[..., 1]
z = t[..., 2]
return torch.stack([
r[..., 0, 0]*x + r[..., 0, 1]*y + r[..., 0, 2]*z,
r[..., 1, 0]*x + r[..., 1, 1]*y + r[..., 1, 2]*z,
r[..., 2, 0]*x + r[..., 2, 1]*y + r[..., 2, 2]*z,
], dim=-1)
return torch.stack(
[
r[..., 0, 0] * x + r[..., 0, 1] * y + r[..., 0, 2] * z,
r[..., 1, 0] * x + r[..., 1, 1] * y + r[..., 1, 2] * z,
r[..., 2, 0] * x + r[..., 2, 1] * y + r[..., 2, 2] * z,
],
dim=-1,
)
class T:
def __init__(self, rots, trans):
self.rots = rots
self.trans = trans
if(self.rots is None and self.trans is None):
if self.rots is None and self.trans is None:
raise ValueError("Only one of rots and trans can be None")
elif(self.rots is None):
elif self.rots is None:
self.rots = T.identity_rot(
self.trans.shape[:-1],
self.trans.dtype,
self.trans.device,
self.trans.shape[:-1],
self.trans.dtype,
self.trans.device,
self.trans.requires_grad,
)
elif(self.trans is None):
elif self.trans is None:
self.trans = T.identity_trans(
self.rots.shape[:-2],
self.rots.dtype,
self.rots.device,
self.rots.requires_grad
self.rots.shape[:-2],
self.rots.dtype,
self.rots.device,
self.rots.requires_grad,
)
if(self.rots.shape[-2:] != (3, 3) or
self.trans.shape[-1] != 3 or
self.rots.shape[:-2] != self.trans.shape[:-1]):
if (
self.rots.shape[-2:] != (3, 3)
or self.trans.shape[-1] != 3
or self.rots.shape[:-2] != self.trans.shape[:-1]
):
raise ValueError("Incorrectly shaped input")
def __getitem__(self, index):
if(type(index) != tuple):
if type(index) != tuple:
index = (index,)
return T(
self.rots[index + (slice(None), slice(None))],
self.trans[index + (slice(None),)]
self.rots[index + (slice(None), slice(None))],
self.trans[index + (slice(None),)],
)
def __eq__(self, obj):
return (
torch.all(self.rots == obj.rots) and
torch.all(self.trans == obj.trans)
return torch.all(self.rots == obj.rots) and torch.all(
self.trans == obj.trans
)
def __mul__(self, right):
......@@ -135,7 +166,7 @@ class T:
return T(rot_inv, -1 * trn_inv)
def unsqueeze(self, dim):
if(dim >= len(self.shape)):
if dim >= len(self.shape):
raise ValueError("Invalid dimension")
rots = self.rots.unsqueeze(dim if dim >= 0 else dim - 2)
trans = self.trans.unsqueeze(dim if dim >= 0 else dim - 1)
......@@ -155,17 +186,14 @@ class T:
@staticmethod
def identity_trans(shape, dtype, device, requires_grad):
trans = torch.zeros(
(*shape, 3),
dtype=dtype,
device=device,
requires_grad=requires_grad
)
(*shape, 3), dtype=dtype, device=device, requires_grad=requires_grad
)
return trans
@staticmethod
def identity(shape, dtype, device, requires_grad=True):
return T(
T.identity_rot(shape, dtype, device, requires_grad),
T.identity_rot(shape, dtype, device, requires_grad),
T.identity_trans(shape, dtype, device, requires_grad),
)
......@@ -191,7 +219,7 @@ class T:
p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1)
origin = torch.unbind(origin, dim=-1)
p_xy_plane = torch.unbind(p_xy_plane, dim=-1)
e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)]
e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)]
......@@ -209,35 +237,31 @@ class T:
rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1)
rots = rots.reshape(rots.shape[:-1] + (3, 3))
return T(rots, torch.stack(origin, dim=-1))
@staticmethod
def concat(ts, dim):
rots = torch.cat(
[t.rots for t in ts],
dim=dim if dim >= 0 else dim - 2
)
rots = torch.cat([t.rots for t in ts], dim=dim if dim >= 0 else dim - 2)
trans = torch.cat(
[t.trans for t in ts],
dim=dim if dim >= 0 else dim - 1
[t.trans for t in ts], dim=dim if dim >= 0 else dim - 1
)
return T(rots, trans)
def map_tensor_fn(self, fn):
"""
Apply a function that takes a tensor as its only argument to the
rotations and translations, treating the final two/one
dimension(s), respectively, as batch dimensions.
"""
Apply a function that takes a tensor as its only argument to the
rotations and translations, treating the final two/one
dimension(s), respectively, as batch dimensions.
E.g.: Given t, an instance of T of shape [N, M], this function can
be used to sum out the second dimension thereof as follows:
E.g.: Given t, an instance of T of shape [N, M], this function can
be used to sum out the second dimension thereof as follows:
t = t.map_tensor_fn(lambda x: torch.sum(x, dim=-1))
t = t.map_tensor_fn(lambda x: torch.sum(x, dim=-1))
The resulting object has rotations of shape [N, 3, 3] and
translations of shape [N, 3]
The resulting object has rotations of shape [N, 3, 3] and
translations of shape [N, 3]
"""
rots = self.rots.view(*self.rots.shape[:-2], 9)
rots = torch.stack(list(map(fn, torch.unbind(rots, -1))), dim=-1)
......@@ -260,7 +284,7 @@ class T:
c_xyz = c_xyz + translation
c_x, c_y, c_z = [c_xyz[..., i] for i in range(3)]
norm = torch.sqrt(eps + c_x**2 + c_y**2)
norm = torch.sqrt(eps + c_x ** 2 + c_y ** 2)
sin_c1 = -c_y / norm
cos_c1 = c_x / norm
zeros = sin_c1.new_zeros(sin_c1.shape)
......@@ -273,9 +297,9 @@ class T:
c1_rots[..., 1, 1] = cos_c1
c1_rots[..., 2, 2] = 1
norm = torch.sqrt(eps + c_x**2 + c_y**2 + c_z**2)
norm = torch.sqrt(eps + c_x ** 2 + c_y ** 2 + c_z ** 2)
sin_c2 = c_z / norm
cos_c2 = torch.sqrt(c_x**2 + c_y**2) / norm
cos_c2 = torch.sqrt(c_x ** 2 + c_y ** 2) / norm
c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
c2_rots[..., 0, 0] = cos_c2
......@@ -288,14 +312,14 @@ class T:
n_xyz = rot_vec_mul(c_rots, n_xyz)
_, n_y, n_z = [n_xyz[..., i] for i in range(3)]
norm = torch.sqrt(eps + n_y**2 + n_z**2)
norm = torch.sqrt(eps + n_y ** 2 + n_z ** 2)
sin_n = -n_z / norm
cos_n = n_y / norm
n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
n_rots[..., 0, 0] = 1
n_rots[..., 1, 1] = cos_n
n_rots[..., 1, 2] = -1 * sin_n
n_rots[..., 1, 2] = -1 * sin_n
n_rots[..., 2, 1] = sin_n
n_rots[..., 2, 2] = cos_n
......@@ -309,10 +333,11 @@ class T:
def cuda(self):
return T(self.rots.cuda(), self.trans.cuda())
_quat_elements = ['a', 'b', 'c', 'd']
_quat_elements = ["a", "b", "c", "d"]
_qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]
_qtr_ind_dict = {key:ind for ind, key in enumerate(_qtr_keys)}
_qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)}
def _to_mat(pairs):
mat = torch.zeros((4, 4))
......@@ -323,20 +348,20 @@ def _to_mat(pairs):
return mat
_qtr_mat = np.zeros((4, 4, 3, 3))
_qtr_mat[..., 0, 0] = _to_mat([('aa', 1), ('bb', 1), ('cc', -1), ('dd', -1)])
_qtr_mat[..., 0, 1] = _to_mat([('bc', 2), ('ad', -2)])
_qtr_mat[..., 0, 2] = _to_mat([('bd', 2), ('ac', 2)])
_qtr_mat[..., 1, 0] = _to_mat([('bc', 2), ('ad', 2)])
_qtr_mat[..., 1, 1] = _to_mat([('aa', 1), ('bb', -1), ('cc', 1), ('dd', -1)])
_qtr_mat[..., 1, 2] = _to_mat([('cd', 2), ('ab', -2)])
_qtr_mat[..., 2, 0] = _to_mat([('bd', 2), ('ac', -2)])
_qtr_mat[..., 2, 1] = _to_mat([('cd', 2), ('ab', 2)])
_qtr_mat[..., 2, 2] = _to_mat([('aa', 1), ('bb', -1), ('cc', -1), ('dd', 1)])
def quat_to_rot(
quat # [*, 4]
):
_qtr_mat[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)])
_qtr_mat[..., 0, 1] = _to_mat([("bc", 2), ("ad", -2)])
_qtr_mat[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)])
_qtr_mat[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)])
_qtr_mat[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)])
_qtr_mat[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)])
_qtr_mat[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)])
_qtr_mat[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)])
_qtr_mat[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)])
def quat_to_rot(quat): # [*, 4]
# [*, 4, 4]
quat = quat[..., None] * quat[..., None, :]
......@@ -350,6 +375,7 @@ def quat_to_rot(
# [*, 3, 3]
return torch.sum(quat, dim=(-3, -4))
def affine_vector_to_4x4(vector):
quats = vector[..., :4]
trans = vector[..., 4:]
......
......@@ -20,31 +20,32 @@ from typing import Any, Tuple, List, Callable
BLOCK_ARG = Any
BLOCK_ARGS = List[BLOCK_ARG]
def checkpoint_blocks(
blocks: List[Callable],
args: BLOCK_ARGS,
blocks: List[Callable],
args: BLOCK_ARGS,
blocks_per_ckpt: int,
) -> BLOCK_ARGS:
"""
Chunk a list of blocks and run each chunk with activation
checkpointing. We define a "block" as a callable whose only inputs are
the outputs of the previous block.
Chunk a list of blocks and run each chunk with activation
checkpointing. We define a "block" as a callable whose only inputs are
the outputs of the previous block.
This function assumes that deepspeed has already been initialized.
This function assumes that deepspeed has already been initialized.
Implements Subsection 1.11.8
Implements Subsection 1.11.8
Args:
blocks:
List of blocks
args:
Tuple of arguments for the first block.
blocks_per_ckpt:
Size of each chunk. A higher value corresponds to higher memory
consumption but fewer checkpoints. If None, no checkpointing is
performed.
Returns:
The output of the final block
Args:
blocks:
List of blocks
args:
Tuple of arguments for the first block.
blocks_per_ckpt:
Size of each chunk. A higher value corresponds to higher memory
consumption but fewer checkpoints. If None, no checkpointing is
performed.
Returns:
The output of the final block
"""
def wrap(a):
......@@ -58,19 +59,20 @@ def checkpoint_blocks(
def chunker(s, e):
def exec_sliced(*a):
return exec(blocks[s:e], a)
return exec_sliced
# Avoids mishaps when the blocks take just one argument
args = wrap(args)
if(blocks_per_ckpt is None):
if blocks_per_ckpt is None:
return exec(blocks, args)
elif(blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks)):
elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks):
raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)")
for s in range(0, len(blocks), blocks_per_ckpt):
e = s + blocks_per_ckpt
#args = checkpoint(chunker(s, e), *args)
# args = checkpoint(chunker(s, e), *args)
args = deepspeed.checkpointing.checkpoint(chunker(s, e), *args)
args = wrap(args)
......
......@@ -3,26 +3,28 @@ import copy
import torch
import torch.nn as nn
class ExponentialMovingAverage:
"""
Maintains moving averages of parameters with exponential decay
"""
Maintains moving averages of parameters with exponential decay
At each step, the stored copy `copy` of each parameter `param` is
updated as follows:
At each step, the stored copy `copy` of each parameter `param` is
updated as follows:
`copy = decay * copy + (1 - decay) * param`
`copy = decay * copy + (1 - decay) * param`
where `decay` is an attribute of the ExponentialMovingAverage object.
where `decay` is an attribute of the ExponentialMovingAverage object.
"""
def __init__(self, model: nn.Module, decay: float):
"""
Args:
model:
A torch.nn.Module whose parameters are to be tracked
decay:
A value (usually close to 1.) by which updates are
weighted as part of the above formula
"""
Args:
model:
A torch.nn.Module whose parameters are to be tracked
decay:
A value (usually close to 1.) by which updates are
weighted as part of the above formula
"""
super(ExponentialMovingAverage, self).__init__()
self.params = copy.deepcopy(model.state_dict())
......@@ -32,27 +34,29 @@ class ExponentialMovingAverage:
with torch.no_grad():
for k, v in update.items():
stored = state_dict[k]
if(not isinstance(v, torch.Tensor)):
if not isinstance(v, torch.Tensor):
self._update_state_dict_(v, stored)
else:
diff = stored - v
diff *= (1 - self.decay)
diff *= 1 - self.decay
stored -= diff
def update(self, model: torch.nn.Module) -> None:
"""
Updates the stored parameters using the state dict of the provided
module. The module should have the same structure as that used to
initialize the ExponentialMovingAverage object.
Updates the stored parameters using the state dict of the provided
module. The module should have the same structure as that used to
initialize the ExponentialMovingAverage object.
"""
self._update_state_dict_(model.state_dict(), self.params)
self._update_state_dict_(model.state_dict(), self.params)
def load_state_dict(self, state_dict: OrderedDict) -> None:
self.params = state_dict["params"]
self.decay = state_dict["decay"]
def state_dict(self) -> OrderedDict:
return OrderedDict({
"params": self.params,
"decay": self.decay,
})
return OrderedDict(
{
"params": self.params,
"decay": self.decay,
}
)
......@@ -18,10 +18,10 @@ import torch
import torch.nn as nn
from typing import Dict
import openfold.np.residue_constants as rc
import openfold.np.residue_constants as rc
from openfold.utils.affine_utils import T
from openfold.utils.tensor_utils import (
batched_gather,
batched_gather,
one_hot,
tree_map,
tensor_tree_map,
......@@ -29,16 +29,16 @@ from openfold.utils.tensor_utils import (
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
is_gly = (aatype == rc.restype_order['G'])
ca_idx = rc.atom_order['CA']
cb_idx = rc.atom_order['CB']
is_gly = aatype == rc.restype_order["G"]
ca_idx = rc.atom_order["CA"]
cb_idx = rc.atom_order["CB"]
pseudo_beta = torch.where(
is_gly[..., None].expand(*((-1,) * len(is_gly.shape)), 3),
all_atom_positions[..., ca_idx, :],
all_atom_positions[..., cb_idx, :]
all_atom_positions[..., cb_idx, :],
)
if(all_atom_masks is not None):
if all_atom_masks is not None:
pseudo_beta_mask = torch.where(
is_gly,
all_atom_masks[..., ca_idx],
......@@ -65,9 +65,9 @@ def atom14_to_atom37(atom14, batch):
def build_template_angle_feat(template_feats):
template_aatype = template_feats["template_aatype"]
torsion_angles_sin_cos = template_feats["template_torsion_angles_sin_cos"]
alt_torsion_angles_sin_cos = (
template_feats["template_alt_torsion_angles_sin_cos"]
)
alt_torsion_angles_sin_cos = template_feats[
"template_alt_torsion_angles_sin_cos"
]
torsion_angles_mask = template_feats["template_torsion_angles_mask"]
template_angle_feat = torch.cat(
[
......@@ -79,21 +79,24 @@ def build_template_angle_feat(template_feats):
*alt_torsion_angles_sin_cos.shape[:-2], 14
),
torsion_angles_mask,
],
],
dim=-1,
)
return template_angle_feat
def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e8):
def build_template_pair_feat(
batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e8
):
template_mask = batch["template_pseudo_beta_mask"]
template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
# Compute distogram (this seems to differ slightly from Alg. 5)
tpb = batch["template_pseudo_beta"]
dgram = torch.sum(
(tpb[..., None, :] - tpb[..., None, :, :]) ** 2, dim=-1, keepdim=True)
(tpb[..., None, :] - tpb[..., None, :, :]) ** 2, dim=-1, keepdim=True
)
lower = torch.linspace(min_bin, max_bin, no_bins, device=tpb.device) ** 2
upper = torch.cat([lower[:-1], lower.new_tensor([inf])], dim=-1)
dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
......@@ -101,7 +104,8 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e
to_concat = [dgram, template_mask_2d[..., None]]
aatype_one_hot = nn.functional.one_hot(
batch["template_aatype"], rc.restype_num + 2,
batch["template_aatype"],
rc.restype_num + 2,
)
n_res = batch["template_aatype"].shape[-1]
......@@ -116,7 +120,7 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e
)
)
n, ca, c = [rc.atom_order[a] for a in ['N', 'CA', 'C']]
n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]]
# TODO: Consider running this in double precision
affines = T.make_transform_from_reference(
n_xyz=batch["template_all_atom_positions"][..., n, :],
......@@ -127,10 +131,8 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e
points = affines.get_trans()[..., None, :, :]
affine_vec = affines[..., None].invert_apply(points)
inv_distance_scalar = torch.rsqrt(
eps + torch.sum(affine_vec ** 2, dim=-1)
)
inv_distance_scalar = torch.rsqrt(eps + torch.sum(affine_vec ** 2, dim=-1))
t_aa_masks = batch["template_all_atom_mask"]
template_mask = (
......@@ -139,10 +141,10 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e
template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
inv_distance_scalar = inv_distance_scalar * template_mask_2d
unit_vector = (affine_vec * inv_distance_scalar[..., None])
unit_vector = affine_vec * inv_distance_scalar[..., None]
to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1))
to_concat.append(template_mask_2d[..., None])
act = torch.cat(to_concat, dim=-1)
act = act * template_mask_2d[..., None]
......@@ -161,55 +163,62 @@ def build_extra_msa_feat(batch):
# adapted from model/tf/data_transforms.py
def build_msa_feat(batch):
"""Create and concatenate MSA features."""
# Whether there is a domain break. Always zero for chains, but keeping
# for compatibility with domain datasets.
has_break = batch["between_segment_residues"]
aatype_1hot = nn.functional.one_hot(batch['aatype'], num_classes=21)
target_feat = [
has_break.unsqueeze(-1),
aatype_1hot, # Everyone gets the original sequence.
]
msa_1hot = nn.functional.one_hot(batch['msa'], num_classes=23)
has_deletion = batch["deletion_matrix"]
deletion_value = torch.atan(batch['deletion_matrix'] / 3.) * (2. / math.pi)
msa_feat = [
msa_1hot,
has_deletion.unsqueeze(-1),
deletion_value.unsqueeze(-1),
]
if 'cluster_profile' in batch:
deletion_mean_value = (
tf.atan(batch['cluster_deletion_mean'] / 3.) * (2. / np.pi))
msa_feat.extend([
batch['cluster_profile'],
tf.expand_dims(deletion_mean_value, axis=-1),
])
if 'extra_deletion_matrix' in protein:
batch['extra_has_deletion'] = tf.clip_by_value(
batch['extra_deletion_matrix'], 0., 1.)
batch['extra_deletion_value'] = tf.atan(
batch['extra_deletion_matrix'] / 3.) * (2. / np.pi)
batch['msa_feat'] = torch.cat(msa_feat, dim=-1)
batch['target_feat'] = torch.cat(target_feat, dim=-1)
return batch
"""Create and concatenate MSA features."""
# Whether there is a domain break. Always zero for chains, but keeping
# for compatibility with domain datasets.
has_break = batch["between_segment_residues"]
aatype_1hot = nn.functional.one_hot(batch["aatype"], num_classes=21)
target_feat = [
has_break.unsqueeze(-1),
aatype_1hot, # Everyone gets the original sequence.
]
msa_1hot = nn.functional.one_hot(batch["msa"], num_classes=23)
has_deletion = batch["deletion_matrix"]
deletion_value = torch.atan(batch["deletion_matrix"] / 3.0) * (
2.0 / math.pi
)
msa_feat = [
msa_1hot,
has_deletion.unsqueeze(-1),
deletion_value.unsqueeze(-1),
]
if "cluster_profile" in batch:
deletion_mean_value = tf.atan(batch["cluster_deletion_mean"] / 3.0) * (
2.0 / np.pi
)
msa_feat.extend(
[
batch["cluster_profile"],
tf.expand_dims(deletion_mean_value, axis=-1),
]
)
if "extra_deletion_matrix" in protein:
batch["extra_has_deletion"] = tf.clip_by_value(
batch["extra_deletion_matrix"], 0.0, 1.0
)
batch["extra_deletion_value"] = tf.atan(
batch["extra_deletion_matrix"] / 3.0
) * (2.0 / np.pi)
batch["msa_feat"] = torch.cat(msa_feat, dim=-1)
batch["target_feat"] = torch.cat(target_feat, dim=-1)
return batch
def torsion_angles_to_frames(
t: T,
alpha: torch.Tensor,
aatype: torch.Tensor,
t: T,
alpha: torch.Tensor,
aatype: torch.Tensor,
rrgdf: torch.Tensor,
):
# [*, N, 8, 4, 4]
default_4x4 = rrgdf[aatype, ...]
# [*, N, 8] transformations, i.e.
# One [*, N, 8, 3, 3] rotation matrix and
# One [*, N, 8, 3] translation matrix
......@@ -217,12 +226,9 @@ def torsion_angles_to_frames(
bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
bb_rot[..., 1] = 1
# [*, N, 8, 2]
alpha = torch.cat(
[bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha],
dim=-2
)
alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2)
# [*, N, 8, 3, 3]
# Produces rotation matrices of the form:
......@@ -233,7 +239,7 @@ def torsion_angles_to_frames(
# ]
# This follows the original code rather than the supplement, which uses
# different indices.
all_rots = alpha.new_zeros(default_t.rots.shape)
all_rots[..., 0, 0] = 1
all_rots[..., 1, 1] = alpha[..., 1]
......@@ -253,12 +259,14 @@ def torsion_angles_to_frames(
chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)
all_frames_to_bb = T.concat([
all_frames_to_bb = T.concat(
[
all_frames[..., :5],
chi2_frame_to_bb.unsqueeze(-1),
chi3_frame_to_bb.unsqueeze(-1),
chi4_frame_to_bb.unsqueeze(-1),
], dim=-1,
],
dim=-1,
)
all_frames_to_global = t[..., None].compose(all_frames_to_bb)
......@@ -274,20 +282,21 @@ def frames_and_literature_positions_to_atom14_pos(
atom_mask,
lit_positions,
):
# [*, N, 14, 4, 4]
# [*, N, 14, 4, 4]
default_4x4 = default_frames[aatype, ...]
# [*, N, 14]
group_mask = group_idx[aatype, ...]
# [*, N, 14, 8]
group_mask = nn.functional.one_hot(
group_mask, num_classes=default_frames.shape[-3],
group_mask,
num_classes=default_frames.shape[-3],
)
# [*, N, 14, 8]
t_atoms_to_global = t[..., None, :] * group_mask
# [*, N, 14]
t_atoms_to_global = t_atoms_to_global.map_tensor_fn(
lambda x: torch.sum(x, dim=-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment