Unverified commit bb3f51e5, authored by Christina Floristean and committed by GitHub

Merge pull request #405 from aqlaboratory/multimer

Full multimer merge
parents ce211367 c33a0bd6
......@@ -18,10 +18,11 @@ import math
import numpy as np
import torch
import torch.nn as nn
from typing import Dict
from typing import Dict, Union
from openfold.np import protein
import openfold.np.residue_constants as rc
from openfold.utils.geometry import rigid_matrix_vector, rotation_matrix, vector
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import (
batched_gather,
......@@ -89,6 +90,23 @@ def build_template_angle_feat(template_feats):
return template_angle_feat
def dgram_from_positions(
pos: torch.Tensor,
min_bin: float = 3.25,
max_bin: float = 50.75,
no_bins: int = 39,
inf: float = 1e8,
):
dgram = torch.sum(
(pos[..., None, :] - pos[..., None, :, :]) ** 2, dim=-1, keepdim=True
)
lower = torch.linspace(min_bin, max_bin, no_bins, device=pos.device) ** 2
upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
return dgram
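A minimal sanity check for the new dgram_from_positions helper; the tensors below are illustrative, not part of this diff. Each squared pairwise distance is one-hot encoded into the bin whose squared (lower, upper) interval contains it, so every pair activates at most one bin:

pos = torch.randn(8, 3) * 10               # hypothetical [N, 3] coordinates
dgram = dgram_from_positions(pos)          # [N, N, no_bins] one-hot distogram
assert dgram.shape == (8, 8, 39)
assert torch.all(dgram.sum(dim=-1) <= 1)   # zero self-distances fall below min_bin, giving all-zero rows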
def build_template_pair_feat(
batch,
min_bin, max_bin, no_bins,
......@@ -100,12 +118,7 @@ def build_template_pair_feat(
# Compute distogram (this seems to differ slightly from Alg. 5)
tpb = batch["template_pseudo_beta"]
dgram = torch.sum(
(tpb[..., None, :] - tpb[..., None, :, :]) ** 2, dim=-1, keepdim=True
)
lower = torch.linspace(min_bin, max_bin, no_bins, device=tpb.device) ** 2
upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
dgram = dgram_from_positions(tpb, min_bin, max_bin, no_bins, inf)
to_concat = [dgram, template_mask_2d[..., None]]
......@@ -170,18 +183,21 @@ def build_extra_msa_feat(batch):
def torsion_angles_to_frames(
r: Rigid,
r: Union[Rigid, rigid_matrix_vector.Rigid3Array],
alpha: torch.Tensor,
aatype: torch.Tensor,
rrgdf: torch.Tensor,
):
rigid_type = type(r)
# [*, N, 8, 4, 4]
default_4x4 = rrgdf[aatype, ...]
# [*, N, 8] transformations, i.e.
# One [*, N, 8, 3, 3] rotation matrix and
# One [*, N, 8, 3] translation matrix
default_r = r.from_tensor_4x4(default_4x4)
default_r = rigid_type.from_tensor_4x4(default_4x4)
bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
bb_rot[..., 1] = 1
......@@ -201,14 +217,13 @@ def torsion_angles_to_frames(
# This follows the original code rather than the supplement, which uses
# different indices.
all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
all_rots = alpha.new_zeros(default_r.shape + (4, 4))
all_rots[..., 0, 0] = 1
all_rots[..., 1, 1] = alpha[..., 1]
all_rots[..., 1, 2] = -alpha[..., 0]
all_rots[..., 2, 1:] = alpha
all_rots = Rigid(Rotation(rot_mats=all_rots), None)
all_rots[..., 2, 1:3] = alpha
all_rots = rigid_type.from_tensor_4x4(all_rots)
all_frames = default_r.compose(all_rots)
chi2_frame_to_frame = all_frames[..., 5]
......@@ -220,7 +235,7 @@ def torsion_angles_to_frames(
chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)
all_frames_to_bb = Rigid.cat(
all_frames_to_bb = rigid_type.cat(
[
all_frames[..., :5],
chi2_frame_to_bb.unsqueeze(-1),
......@@ -236,7 +251,7 @@ def torsion_angles_to_frames(
def frames_and_literature_positions_to_atom14_pos(
r: Rigid,
r: Union[Rigid, rigid_matrix_vector.Rigid3Array],
aatype: torch.Tensor,
default_frames,
group_idx,
......@@ -263,7 +278,7 @@ def frames_and_literature_positions_to_atom14_pos(
lambda x: torch.sum(x, dim=-1)
)
# [*, N, 14, 1]
# [*, N, 14]
atom_mask = atom_mask[aatype, ...].unsqueeze(-1)
# [*, N, 14, 3]
......
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Geometry Module."""
from openfold.utils.geometry import rigid_matrix_vector
from openfold.utils.geometry import rotation_matrix
from openfold.utils.geometry import vector
Rot3Array = rotation_matrix.Rot3Array
Rigid3Array = rigid_matrix_vector.Rigid3Array
Vec3Array = vector.Vec3Array
square_euclidean_distance = vector.square_euclidean_distance
euclidean_distance = vector.euclidean_distance
dihedral_angle = vector.dihedral_angle
dot = vector.dot
cross = vector.cross
import torch
import torch.nn as nn
from openfold.model.primitives import Linear
from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
from openfold.utils.geometry.rotation_matrix import Rot3Array
from openfold.utils.geometry.vector import Vec3Array
class QuatRigid(nn.Module):
def __init__(self, c_hidden, full_quat):
super().__init__()
self.full_quat = full_quat
if self.full_quat:
rigid_dim = 7
else:
rigid_dim = 6
self.linear = Linear(c_hidden, rigid_dim, init="final", precision=torch.float32)
def forward(self, activations: torch.Tensor) -> Rigid3Array:
# NOTE: During training, this needs to be run in higher precision
rigid_flat = self.linear(activations)
rigid_flat = torch.unbind(rigid_flat, dim=-1)
if self.full_quat:
qw, qx, qy, qz = rigid_flat[:4]
translation = rigid_flat[4:]
else:
qx, qy, qz = rigid_flat[:3]
qw = torch.ones_like(qx)
translation = rigid_flat[3:]
rotation = Rot3Array.from_quaternion(
qw, qx, qy, qz, normalize=True,
)
translation = Vec3Array(*translation)
return Rigid3Array(rotation, translation)
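A brief usage sketch for QuatRigid; the hidden size and shapes are illustrative assumptions. With full_quat=False only the quaternion's vector part is predicted and qw is fixed to 1, so zero activations decode to the identity rigid (assuming OpenFold's "final" init zero-initialises the linear layer):

quat_rigid = QuatRigid(c_hidden=384, full_quat=False)   # hypothetical hidden size
act = torch.zeros(2, 10, 384)                           # [batch, N_res, c_hidden]
rigid = quat_rigid(act)                                 # Rigid3Array of shape (2, 10)
# qw = 1, qx = qy = qz = 0 and zero translation, i.e. the identity transform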
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rigid3Array Transformations represented by a Matrix and a Vector."""
from __future__ import annotations
import dataclasses
from typing import Union, List
import torch
from openfold.utils.geometry import rotation_matrix
from openfold.utils.geometry import vector
Float = Union[float, torch.Tensor]
@dataclasses.dataclass(frozen=True)
class Rigid3Array:
"""Rigid Transformation, i.e. element of special euclidean group."""
rotation: rotation_matrix.Rot3Array
translation: vector.Vec3Array
def __matmul__(self, other: Rigid3Array) -> Rigid3Array:
new_rotation = self.rotation @ other.rotation # __matmul__
new_translation = self.apply_to_point(other.translation)
return Rigid3Array(new_rotation, new_translation)
def __getitem__(self, index) -> Rigid3Array:
return Rigid3Array(
self.rotation[index],
self.translation[index],
)
def __mul__(self, other: torch.Tensor) -> Rigid3Array:
return Rigid3Array(
self.rotation * other,
self.translation * other,
)
def map_tensor_fn(self, fn) -> Rigid3Array:
return Rigid3Array(
self.rotation.map_tensor_fn(fn),
self.translation.map_tensor_fn(fn),
)
def inverse(self) -> Rigid3Array:
"""Return Rigid3Array corresponding to inverse transform."""
inv_rotation = self.rotation.inverse()
inv_translation = inv_rotation.apply_to_point(-self.translation)
return Rigid3Array(inv_rotation, inv_translation)
def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
"""Apply Rigid3Array transform to point."""
return self.rotation.apply_to_point(point) + self.translation
def apply(self, point: torch.Tensor) -> torch.Tensor:
return self.apply_to_point(vector.Vec3Array.from_array(point)).to_tensor()
def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
"""Apply inverse Rigid3Array transform to point."""
new_point = point - self.translation
return self.rotation.apply_inverse_to_point(new_point)
def invert_apply(self, point: torch.Tensor) -> torch.Tensor:
return self.apply_inverse_to_point(vector.Vec3Array.from_array(point)).to_tensor()
def compose_rotation(self, other_rotation):
rot = self.rotation @ other_rotation
return Rigid3Array(rot, self.translation.clone())
def compose(self, other_rigid):
return self @ other_rigid
def unsqueeze(self, dim: int):
return Rigid3Array(
self.rotation.unsqueeze(dim),
self.translation.unsqueeze(dim),
)
@property
def shape(self) -> torch.Size:
return self.rotation.xx.shape
@property
def dtype(self) -> torch.dtype:
return self.rotation.xx.dtype
@property
def device(self) -> torch.device:
return self.rotation.xx.device
@classmethod
def identity(cls, shape, device) -> Rigid3Array:
"""Return identity Rigid3Array of given shape."""
return cls(
rotation_matrix.Rot3Array.identity(shape, device),
vector.Vec3Array.zeros(shape, device)
)
@classmethod
def cat(cls, rigids: List[Rigid3Array], dim: int) -> Rigid3Array:
return cls(
rotation_matrix.Rot3Array.cat(
[r.rotation for r in rigids], dim=dim
),
vector.Vec3Array.cat(
[r.translation for r in rigids], dim=dim
),
)
def scale_translation(self, factor: Float) -> Rigid3Array:
"""Scale translation in Rigid3Array by 'factor'."""
return Rigid3Array(self.rotation, self.translation * factor)
def to_tensor(self) -> torch.Tensor:
rot_array = self.rotation.to_tensor()
vec_array = self.translation.to_tensor()
array = torch.zeros(
rot_array.shape[:-2] + (4, 4),
device=rot_array.device,
dtype=rot_array.dtype
)
array[..., :3, :3] = rot_array
array[..., :3, 3] = vec_array
array[..., 3, 3] = 1.
return array
def to_tensor_4x4(self) -> torch.Tensor:
return self.to_tensor()
def reshape(self, new_shape) -> Rigid3Array:
rots = self.rotation.reshape(new_shape)
trans = self.translation.reshape(new_shape)
return Rigid3Array(rots, trans)
def stop_rot_gradient(self) -> Rigid3Array:
return Rigid3Array(
self.rotation.stop_gradient(),
self.translation,
)
@classmethod
def from_array(cls, array):
rot = rotation_matrix.Rot3Array.from_array(
array[..., :3, :3],
)
vec = vector.Vec3Array.from_array(array[..., :3, 3])
return cls(rot, vec)
@classmethod
def from_tensor_4x4(cls, array):
return cls.from_array(array)
@classmethod
def from_array4x4(cls, array: torch.Tensor) -> Rigid3Array:
"""Construct Rigid3Array from homogeneous 4x4 array."""
rotation = rotation_matrix.Rot3Array(
array[..., 0, 0], array[..., 0, 1], array[..., 0, 2],
array[..., 1, 0], array[..., 1, 1], array[..., 1, 2],
array[..., 2, 0], array[..., 2, 1], array[..., 2, 2]
)
translation = vector.Vec3Array(
array[..., 0, 3], array[..., 1, 3], array[..., 2, 3]
)
return cls(rotation, translation)
def cuda(self) -> Rigid3Array:
return Rigid3Array.from_tensor_4x4(self.to_tensor_4x4().cuda())
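A small round-trip sketch (shapes illustrative): a Rigid3Array survives conversion through its homogeneous 4x4 representation, and applying the identity transform leaves points unchanged:

r = Rigid3Array.identity((5,), device="cpu")   # identity transform of shape (5,)
t = r.to_tensor_4x4()                          # [5, 4, 4] homogeneous matrices
r2 = Rigid3Array.from_tensor_4x4(t)
p = torch.randn(5, 3)
assert torch.allclose(r2.apply(p), p)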
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rot3Array Matrix Class."""
from __future__ import annotations
import dataclasses
from typing import List
import torch
from openfold.utils.geometry import utils
from openfold.utils.geometry import vector
from openfold.utils.tensor_utils import tensor_tree_map
COMPONENTS = ['xx', 'xy', 'xz', 'yx', 'yy', 'yz', 'zx', 'zy', 'zz']
@dataclasses.dataclass(frozen=True)
class Rot3Array:
"""Rot3Array Matrix in 3 dimensional Space implemented as struct of arrays."""
xx: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32})
xy: torch.Tensor
xz: torch.Tensor
yx: torch.Tensor
yy: torch.Tensor
yz: torch.Tensor
zx: torch.Tensor
zy: torch.Tensor
zz: torch.Tensor
__array_ufunc__ = None
def __getitem__(self, index):
field_names = utils.get_field_names(Rot3Array)
return Rot3Array(
**{
name: getattr(self, name)[index]
for name in field_names
}
)
def __mul__(self, other: torch.Tensor):
field_names = utils.get_field_names(Rot3Array)
return Rot3Array(
**{
name: getattr(self, name) * other
for name in field_names
}
)
def __matmul__(self, other: Rot3Array) -> Rot3Array:
"""Composes two Rot3Arrays."""
c0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx))
c1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy))
c2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz))
return Rot3Array(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z)
def map_tensor_fn(self, fn) -> Rot3Array:
field_names = utils.get_field_names(Rot3Array)
return Rot3Array(
**{
name: fn(getattr(self, name))
for name in field_names
}
)
def inverse(self) -> Rot3Array:
"""Returns inverse of Rot3Array."""
return Rot3Array(
self.xx, self.yx, self.zx,
self.xy, self.yy, self.zy,
self.xz, self.yz, self.zz
)
def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
"""Applies Rot3Array to point."""
return vector.Vec3Array(
self.xx * point.x + self.xy * point.y + self.xz * point.z,
self.yx * point.x + self.yy * point.y + self.yz * point.z,
self.zx * point.x + self.zy * point.y + self.zz * point.z
)
def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
"""Applies inverse Rot3Array to point."""
return self.inverse().apply_to_point(point)
def unsqueeze(self, dim: int):
return Rot3Array(
*tensor_tree_map(
lambda t: t.unsqueeze(dim),
[getattr(self, c) for c in COMPONENTS]
)
)
def stop_gradient(self) -> Rot3Array:
return Rot3Array(
*[getattr(self, c).detach() for c in COMPONENTS]
)
@classmethod
def identity(cls, shape, device) -> Rot3Array:
"""Returns identity of given shape."""
ones = torch.ones(shape, dtype=torch.float32, device=device)
zeros = torch.zeros(shape, dtype=torch.float32, device=device)
return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones)
@classmethod
def from_two_vectors(
cls, e0: vector.Vec3Array,
e1: vector.Vec3Array
) -> Rot3Array:
"""Construct Rot3Array from two Vectors.
Rot3Array is constructed such that in the corresponding frame 'e0' lies on
the positive x-Axis and 'e1' lies in the xy plane with positive sign of y.
Args:
e0: Vector
e1: Vector
Returns:
Rot3Array
"""
# Normalize the unit vector for the x-axis, e0.
e0 = e0.normalized()
# Make e1 perpendicular to e0.
c = e1.dot(e0)
e1 = (e1 - c * e0).normalized()
# Compute e2 as cross product of e0 and e1.
e2 = e0.cross(e1)
return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z)
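An illustration of the Gram-Schmidt construction with made-up coordinates, much as residue frames are derived from backbone N, CA and C positions in AlphaFold-style models. With e0 along x and e1 already in the xy-plane, the result is the identity rotation:

ca = vector.Vec3Array.from_array(torch.tensor([0.0, 0.0, 0.0]))
c = vector.Vec3Array.from_array(torch.tensor([1.5, 0.0, 0.0]))
n = vector.Vec3Array.from_array(torch.tensor([0.0, 1.4, 0.0]))
rot = Rot3Array.from_two_vectors(c - ca, n - ca)
assert torch.allclose(rot.to_tensor(), torch.eye(3))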
@classmethod
def from_array(cls, array: torch.Tensor) -> Rot3Array:
"""Construct Rot3Array Matrix from array of shape. [..., 3, 3]."""
rows = torch.unbind(array, dim=-2)
rc = [torch.unbind(e, dim=-1) for e in rows]
return cls(*[e for row in rc for e in row])
def to_tensor(self) -> torch.Tensor:
"""Convert Rot3Array to array of shape [..., 3, 3]."""
return torch.stack(
[
torch.stack([self.xx, self.xy, self.xz], dim=-1),
torch.stack([self.yx, self.yy, self.yz], dim=-1),
torch.stack([self.zx, self.zy, self.zz], dim=-1)
],
dim=-2
)
@classmethod
def from_quaternion(cls,
w: torch.Tensor,
x: torch.Tensor,
y: torch.Tensor,
z: torch.Tensor,
normalize: bool = True,
eps: float = 1e-6
) -> Rot3Array:
"""Construct Rot3Array from components of quaternion."""
if normalize:
inv_norm = torch.rsqrt(torch.clamp(w**2 + x**2 + y**2 + z**2, min=eps))
w = w * inv_norm
x = x * inv_norm
y = y * inv_norm
z = z * inv_norm
xx = 1.0 - 2.0 * (y ** 2 + z ** 2)
xy = 2.0 * (x * y - w * z)
xz = 2.0 * (x * z + w * y)
yx = 2.0 * (x * y + w * z)
yy = 1.0 - 2.0 * (x ** 2 + z ** 2)
yz = 2.0 * (y * z - w * x)
zx = 2.0 * (x * z - w * y)
zy = 2.0 * (y * z + w * x)
zz = 1.0 - 2.0 * (x ** 2 + y ** 2)
return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz)
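Two quick consistency checks with illustrative values: the identity quaternion maps to the identity matrix, and normalize=True makes the constructor invariant to quaternion scale:

identity = Rot3Array.from_quaternion(
    torch.tensor(1.0), torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0)
)
assert torch.allclose(identity.to_tensor(), torch.eye(3))
scaled = Rot3Array.from_quaternion(
    torch.tensor(2.0), torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0)
)
assert torch.allclose(scaled.to_tensor(), torch.eye(3))   # same rotation after normalisation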
def reshape(self, new_shape):
field_names = utils.get_field_names(Rot3Array)
reshape_fn = lambda t: t.reshape(new_shape)
return Rot3Array(
**{
name: reshape_fn(getattr(self, name))
for name in field_names
}
)
@classmethod
def cat(cls, rots: List[Rot3Array], dim: int) -> Rot3Array:
field_names = utils.get_field_names(Rot3Array)
cat_fn = lambda l: torch.cat(l, dim=dim)
return cls(
**{
name: cat_fn([getattr(r, name) for r in rots])
for name in field_names
}
)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared utils for tests."""
import dataclasses
import torch
from openfold.utils.geometry import rigid_matrix_vector
from openfold.utils.geometry import rotation_matrix
from openfold.utils.geometry import vector
def assert_rotation_matrix_equal(matrix1: rotation_matrix.Rot3Array,
matrix2: rotation_matrix.Rot3Array):
for field in dataclasses.fields(rotation_matrix.Rot3Array):
field = field.name
assert torch.equal(
getattr(matrix1, field), getattr(matrix2, field))
def assert_rotation_matrix_close(mat1: rotation_matrix.Rot3Array,
mat2: rotation_matrix.Rot3Array):
assert torch.allclose(mat1.to_tensor(), mat2.to_tensor(), atol=1e-6)
def assert_array_equal_to_rotation_matrix(array: torch.Tensor,
matrix: rotation_matrix.Rot3Array):
"""Check that array and Matrix match."""
assert torch.equal(matrix.xx, array[..., 0, 0])
assert torch.equal(matrix.xy, array[..., 0, 1])
assert torch.equal(matrix.xz, array[..., 0, 2])
assert torch.equal(matrix.yx, array[..., 1, 0])
assert torch.equal(matrix.yy, array[..., 1, 1])
assert torch.equal(matrix.yz, array[..., 1, 2])
assert torch.equal(matrix.zx, array[..., 2, 0])
assert torch.equal(matrix.zy, array[..., 2, 1])
assert torch.equal(matrix.zz, array[..., 2, 2])
def assert_array_close_to_rotation_matrix(array: torch.Tensor,
matrix: rotation_matrix.Rot3Array):
assert torch.allclose(matrix.to_tensor(), array, atol=1e-6)
def assert_vectors_equal(vec1: vector.Vec3Array, vec2: vector.Vec3Array):
assert torch.equal(vec1.x, vec2.x)
assert torch.equal(vec1.y, vec2.y)
assert torch.equal(vec1.z, vec2.z)
def assert_vectors_close(vec1: vector.Vec3Array, vec2: vector.Vec3Array):
assert torch.allclose(vec1.x, vec2.x, atol=1e-6, rtol=0.)
assert torch.allclose(vec1.y, vec2.y, atol=1e-6, rtol=0.)
assert torch.allclose(vec1.z, vec2.z, atol=1e-6, rtol=0.)
def assert_array_close_to_vector(array: torch.Tensor, vec: vector.Vec3Array):
assert torch.allclose(vec.to_tensor(), array, atol=1e-6, rtol=0.)
def assert_array_equal_to_vector(array: torch.Tensor, vec: vector.Vec3Array):
assert torch.equal(vec.to_tensor(), array)
def assert_rigid_equal_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array,
rigid2: rigid_matrix_vector.Rigid3Array):
assert_rot_trans_equal_to_rigid(rigid1.rotation, rigid1.translation, rigid2)
def assert_rigid_close_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array,
rigid2: rigid_matrix_vector.Rigid3Array):
assert_rot_trans_close_to_rigid(rigid1.rotation, rigid1.translation, rigid2)
def assert_rot_trans_equal_to_rigid(rot: rotation_matrix.Rot3Array,
trans: vector.Vec3Array,
rigid: rigid_matrix_vector.Rigid3Array):
assert_rotation_matrix_equal(rot, rigid.rotation)
assert_vectors_equal(trans, rigid.translation)
def assert_rot_trans_close_to_rigid(rot: rotation_matrix.Rot3Array,
trans: vector.Vec3Array,
rigid: rigid_matrix_vector.Rigid3Array):
assert_rotation_matrix_close(rot, rigid.rotation)
assert_vectors_close(trans, rigid.translation)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for geometry library."""
import dataclasses
def get_field_names(cls):
fields = dataclasses.fields(cls)
field_names = [f.name for f in fields]
return field_names
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vec3Array Class."""
from __future__ import annotations
import dataclasses
from typing import Union, List
import torch
Float = Union[float, torch.Tensor]
@dataclasses.dataclass(frozen=True)
class Vec3Array:
x: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32})
y: torch.Tensor
z: torch.Tensor
def __post_init__(self):
if hasattr(self.x, 'dtype'):
assert self.x.dtype == self.y.dtype
assert self.x.dtype == self.z.dtype
assert all([x == y for x, y in zip(self.x.shape, self.y.shape)])
assert all([x == z for x, z in zip(self.x.shape, self.z.shape)])
def __add__(self, other: Vec3Array) -> Vec3Array:
return Vec3Array(
self.x + other.x,
self.y + other.y,
self.z + other.z,
)
def __sub__(self, other: Vec3Array) -> Vec3Array:
return Vec3Array(
self.x - other.x,
self.y - other.y,
self.z - other.z,
)
def __mul__(self, other: Float) -> Vec3Array:
return Vec3Array(
self.x * other,
self.y * other,
self.z * other,
)
def __rmul__(self, other: Float) -> Vec3Array:
return self * other
def __truediv__(self, other: Float) -> Vec3Array:
return Vec3Array(
self.x / other,
self.y / other,
self.z / other,
)
def __neg__(self) -> Vec3Array:
return self * -1
def __pos__(self) -> Vec3Array:
return self * 1
def __getitem__(self, index) -> Vec3Array:
return Vec3Array(
self.x[index],
self.y[index],
self.z[index],
)
def __iter__(self):
return iter((self.x, self.y, self.z))
@property
def shape(self):
return self.x.shape
def map_tensor_fn(self, fn) -> Vec3Array:
return Vec3Array(
fn(self.x),
fn(self.y),
fn(self.z),
)
def cross(self, other: Vec3Array) -> Vec3Array:
"""Compute cross product between 'self' and 'other'."""
new_x = self.y * other.z - self.z * other.y
new_y = self.z * other.x - self.x * other.z
new_z = self.x * other.y - self.y * other.x
return Vec3Array(new_x, new_y, new_z)
def dot(self, other: Vec3Array) -> Float:
"""Compute dot product between 'self' and 'other'."""
return self.x * other.x + self.y * other.y + self.z * other.z
def norm(self, epsilon: float = 1e-6) -> Float:
"""Compute Norm of Vec3Array, clipped to epsilon."""
# To avoid NaN on the backward pass, we must use maximum before the sqrt
norm2 = self.dot(self)
if epsilon:
norm2 = torch.clamp(norm2, min=epsilon**2)
return torch.sqrt(norm2)
def norm2(self):
return self.dot(self)
def normalized(self, epsilon: float = 1e-6) -> Vec3Array:
"""Return unit vector with optional clipping."""
return self / self.norm(epsilon)
def clone(self) -> Vec3Array:
return Vec3Array(
self.x.clone(),
self.y.clone(),
self.z.clone(),
)
def reshape(self, new_shape) -> Vec3Array:
x = self.x.reshape(new_shape)
y = self.y.reshape(new_shape)
z = self.z.reshape(new_shape)
return Vec3Array(x, y, z)
def sum(self, dim: int) -> Vec3Array:
return Vec3Array(
torch.sum(self.x, dim=dim),
torch.sum(self.y, dim=dim),
torch.sum(self.z, dim=dim),
)
def unsqueeze(self, dim: int):
return Vec3Array(
self.x.unsqueeze(dim),
self.y.unsqueeze(dim),
self.z.unsqueeze(dim),
)
@classmethod
def zeros(cls, shape, device="cpu"):
"""Return Vec3Array corresponding to zeros of given shape."""
return cls(
torch.zeros(shape, dtype=torch.float32, device=device),
torch.zeros(shape, dtype=torch.float32, device=device),
torch.zeros(shape, dtype=torch.float32, device=device)
)
def to_tensor(self) -> torch.Tensor:
return torch.stack([self.x, self.y, self.z], dim=-1)
@classmethod
def from_array(cls, tensor):
return cls(*torch.unbind(tensor, dim=-1))
@classmethod
def cat(cls, vecs: List[Vec3Array], dim: int) -> Vec3Array:
return cls(
torch.cat([v.x for v in vecs], dim=dim),
torch.cat([v.y for v in vecs], dim=dim),
torch.cat([v.z for v in vecs], dim=dim),
)
def square_euclidean_distance(
vec1: Vec3Array,
vec2: Vec3Array,
epsilon: float = 1e-6
) -> Float:
"""Computes square of euclidean distance between 'vec1' and 'vec2'.
Args:
vec1: Vec3Array to compute distance to
vec2: Vec3Array to compute distance from, should be
broadcast compatible with 'vec1'
epsilon: distance is clipped from below to be at least epsilon
Returns:
Array of square euclidean distances;
shape will be result of broadcasting 'vec1' and 'vec2'
"""
difference = vec1 - vec2
distance = difference.dot(difference)
if epsilon:
distance = torch.clamp(distance, min=epsilon)
return distance
def dot(vector1: Vec3Array, vector2: Vec3Array) -> Float:
return vector1.dot(vector2)
def cross(vector1: Vec3Array, vector2: Vec3Array) -> Vec3Array:
return vector1.cross(vector2)
def norm(vector: Vec3Array, epsilon: float = 1e-6) -> Float:
return vector.norm(epsilon)
def normalized(vector: Vec3Array, epsilon: float = 1e-6) -> Vec3Array:
return vector.normalized(epsilon)
def euclidean_distance(
vec1: Vec3Array,
vec2: Vec3Array,
epsilon: float = 1e-6
) -> Float:
"""Computes euclidean distance between 'vec1' and 'vec2'.
Args:
vec1: Vec3Array to compute euclidean distance to
vec2: Vec3Array to compute euclidean distance from, should be
broadcast compatible with 'vec1'
epsilon: distance is clipped from below to be at least epsilon
Returns:
Array of euclidean distances;
shape will be result of broadcasting 'vec1' and 'vec2'
"""
distance_sq = square_euclidean_distance(vec1, vec2, epsilon**2)
distance = torch.sqrt(distance_sq)
return distance
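A quick numeric check of the distance helpers using a 3-4-5 triangle (values illustrative):

v1 = Vec3Array.from_array(torch.tensor([3.0, 4.0, 0.0]))
v2 = Vec3Array.zeros(())
assert torch.isclose(v1.norm(), torch.tensor(5.0))
assert torch.isclose(square_euclidean_distance(v1, v2), torch.tensor(25.0))
assert torch.isclose(euclidean_distance(v1, v2), torch.tensor(5.0))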
def dihedral_angle(a: Vec3Array, b: Vec3Array, c: Vec3Array,
d: Vec3Array) -> Float:
"""Computes torsion angle for a quadruple of points.
For points (a, b, c, d), this is the angle between the planes defined by
points (a, b, c) and (b, c, d). It is also known as the dihedral angle.
Arguments:
a: A Vec3Array of coordinates.
b: A Vec3Array of coordinates.
c: A Vec3Array of coordinates.
d: A Vec3Array of coordinates.
Returns:
A tensor of angles in radians: [-pi, pi].
"""
v1 = a - b
v2 = b - c
v3 = d - c
c1 = v1.cross(v2)
c2 = v3.cross(v2)
c3 = c2.cross(c1)
v2_mag = v2.norm()
return torch.atan2(c3.dot(v2), v2_mag * c1.dot(c2))
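A worked example for dihedral_angle, with points chosen so the (a, b, c) and (b, c, d) planes are perpendicular:

a = Vec3Array.from_array(torch.tensor([1.0, 0.0, 0.0]))
b = Vec3Array.from_array(torch.tensor([0.0, 0.0, 0.0]))
c = Vec3Array.from_array(torch.tensor([0.0, 1.0, 0.0]))
d = Vec3Array.from_array(torch.tensor([0.0, 1.0, 1.0]))
angle = dihedral_angle(a, b, c, d)   # evaluates to -pi/2 for this configuration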
......@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from enum import Enum
from dataclasses import dataclass
from functools import partial
......@@ -27,7 +28,7 @@ _NPZ_KEY_PREFIX = "alphafold/alphafold_iteration/"
# With Param, a poor man's enum with attributes (Rust-style)
class ParamType(Enum):
LinearWeight = partial( # hack: partial prevents fns from becoming methods
lambda w: w.transpose(-1, -2)
lambda w: w.unsqueeze(-1) if len(w.shape) == 1 else w.transpose(-1, -2)
)
LinearWeightMHA = partial(
lambda w: w.reshape(*w.shape[:-2], -1).transpose(-1, -2)
......@@ -39,6 +40,13 @@ class ParamType(Enum):
LinearWeightOPM = partial(
lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
)
LinearWeightMultimer = partial(
lambda w: w.unsqueeze(-1) if len(w.shape) == 1 else
w.reshape(w.shape[0], -1).transpose(-1, -2)
)
LinearBiasMultimer = partial(
lambda w: w.reshape(-1)
)
Other = partial(lambda w: w)
def __init__(self, fn):
......@@ -50,6 +58,7 @@ class Param:
param: Union[torch.Tensor, List[torch.Tensor]]
param_type: ParamType = ParamType.Other
stacked: bool = False
swap: bool = False
def process_translation_dict(d, top_layer=True):
......@@ -93,6 +102,7 @@ def stacked(param_dict_list, out=None):
param=[param.param for param in v],
param_type=v[0].param_type,
stacked=True,
swap=v[0].swap
)
out[k] = stacked_param
......@@ -114,6 +124,11 @@ def assign(translation_dict, orig_weights):
try:
weights = list(map(param_type.transformation, weights))
for p, w in zip(ref, weights):
if param.swap:
index = p.shape[0] // 2
p[:index].copy_(w[index:])
p[index:].copy_(w[:index])
else:
p.copy_(w)
except:
print(k)
......@@ -122,26 +137,44 @@ def assign(translation_dict, orig_weights):
raise
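The new swap branch above exchanges the two output halves of a fused parameter when copying; a standalone sketch of the same copy pattern (tensors illustrative):

w = torch.arange(4.0)          # pretend fused halves [0, 1] and [2, 3]
p = torch.empty_like(w)
index = p.shape[0] // 2
p[:index].copy_(w[index:])
p[index:].copy_(w[:index])     # p is now tensor([2., 3., 0., 1.])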
def generate_translation_dict(model, version):
def generate_translation_dict(model, version, is_multimer=False):
#######################
# Some templates
#######################
LinearWeight = lambda l: (Param(l, param_type=ParamType.LinearWeight))
LinearBias = lambda l: (Param(l))
LinearWeightMHA = lambda l: (Param(l, param_type=ParamType.LinearWeightMHA))
LinearBiasMHA = lambda b: (Param(b, param_type=ParamType.LinearBiasMHA))
LinearWeightOPM = lambda l: (Param(l, param_type=ParamType.LinearWeightOPM))
LinearWeightMultimer = lambda l: (
Param(l, param_type=ParamType.LinearWeightMultimer)
)
LinearBiasMultimer = lambda l: (
Param(l, param_type=ParamType.LinearBiasMultimer)
)
LinearWeightSwap = lambda l: (Param(l, param_type=ParamType.LinearWeight, swap=True))
LinearBiasSwap = lambda l: (Param(l, swap=True))
LinearParams = lambda l: {
"weights": LinearWeight(l.weight),
"bias": LinearBias(l.bias),
}
LinearParamsMHA = lambda l: {
"weights": LinearWeightMHA(l.weight),
"bias": LinearBiasMHA(l.bias),
}
LinearParamsSwap = lambda l: {
"weights": LinearWeightSwap(l.weight),
"bias": LinearBiasSwap(l.bias),
}
LinearParamsMultimer = lambda l: {
"weights": LinearWeightMultimer(l.weight),
"bias": LinearBiasMultimer(l.bias),
}
LayerNormParams = lambda l: {
"scale": Param(l.weight),
"offset": Param(l.bias),
......@@ -178,31 +211,48 @@ def generate_translation_dict(model, version):
"attention": AttentionGatedParams(tri_att.mha),
}
TriMulOutParams = lambda tri_mul: {
"layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
"left_projection": LinearParams(tri_mul.linear_a_p),
"right_projection": LinearParams(tri_mul.linear_b_p),
"left_gate": LinearParams(tri_mul.linear_a_g),
"right_gate": LinearParams(tri_mul.linear_b_g),
"center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
"output_projection": LinearParams(tri_mul.linear_z),
"gating_linear": LinearParams(tri_mul.linear_g),
def TriMulOutParams(tri_mul, outgoing=True):
if re.fullmatch("^model_[1-5]_multimer_v3$", version):
lin_param_type = LinearParams if outgoing else LinearParamsSwap
d = {
"left_norm_input": LayerNormParams(tri_mul.layer_norm_in),
"projection": lin_param_type(tri_mul.linear_ab_p),
"gate": lin_param_type(tri_mul.linear_ab_g),
"center_norm": LayerNormParams(tri_mul.layer_norm_out),
}
else:
# see commit b88f8da on the Alphafold repo
# Alphafold swaps the pseudocode's a and b between the incoming/outgoing
# iterations of triangle multiplication, which is confusing and not
# reproduced in our implementation.
TriMulInParams = lambda tri_mul: {
if outgoing:
left_projection = LinearParams(tri_mul.linear_a_p)
right_projection = LinearParams(tri_mul.linear_b_p)
left_gate = LinearParams(tri_mul.linear_a_g)
right_gate = LinearParams(tri_mul.linear_b_g)
else:
left_projection = LinearParams(tri_mul.linear_b_p)
right_projection = LinearParams(tri_mul.linear_a_p)
left_gate = LinearParams(tri_mul.linear_b_g)
right_gate = LinearParams(tri_mul.linear_a_g)
d = {
"layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
"left_projection": LinearParams(tri_mul.linear_b_p),
"right_projection": LinearParams(tri_mul.linear_a_p),
"left_gate": LinearParams(tri_mul.linear_b_g),
"right_gate": LinearParams(tri_mul.linear_a_g),
"left_projection": left_projection,
"right_projection": right_projection,
"left_gate": left_gate,
"right_gate": right_gate,
"center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
}
d.update({
"output_projection": LinearParams(tri_mul.linear_z),
"gating_linear": LinearParams(tri_mul.linear_g),
}
})
return d
TriMulInParams = partial(TriMulOutParams, outgoing=False)
PairTransitionParams = lambda pt: {
"input_layer_norm": LayerNormParams(pt.layer_norm),
......@@ -236,8 +286,46 @@ def generate_translation_dict(model, version):
IPAParams = lambda ipa: {
"q_scalar": LinearParams(ipa.linear_q),
"kv_scalar": LinearParams(ipa.linear_kv),
"q_point_local": LinearParams(ipa.linear_q_points),
"kv_point_local": LinearParams(ipa.linear_kv_points),
"q_point_local": LinearParams(ipa.linear_q_points.linear),
"kv_point_local": LinearParams(ipa.linear_kv_points.linear),
"trainable_point_weights": Param(
param=ipa.head_weights, param_type=ParamType.Other
),
"attention_2d": LinearParams(ipa.linear_b),
"output_projection": LinearParams(ipa.linear_out),
}
PointProjectionParams = lambda pp: {
"point_projection": LinearParamsMHA(
pp.linear,
),
}
IPAParamsMultimer = lambda ipa: {
"q_scalar_projection": {
"weights": LinearWeightMHA(
ipa.linear_q.weight,
),
},
"k_scalar_projection": {
"weights": LinearWeightMHA(
ipa.linear_k.weight,
),
},
"v_scalar_projection": {
"weights": LinearWeightMHA(
ipa.linear_v.weight,
),
},
"q_point_projection": PointProjectionParams(
ipa.linear_q_points
),
"k_point_projection": PointProjectionParams(
ipa.linear_k_points
),
"v_point_projection": PointProjectionParams(
ipa.linear_v_points
),
"trainable_point_weights": Param(
param=ipa.head_weights, param_type=ParamType.Other
),
......@@ -280,27 +368,29 @@ def generate_translation_dict(model, version):
b.msa_att_row
),
col_att_name: msa_col_att_params,
"msa_transition": MSATransitionParams(b.core.msa_transition),
"msa_transition": MSATransitionParams(b.msa_transition),
"outer_product_mean":
OuterProductMeanParams(b.core.outer_product_mean),
OuterProductMeanParams(b.outer_product_mean),
"triangle_multiplication_outgoing":
TriMulOutParams(b.core.tri_mul_out),
TriMulOutParams(b.pair_stack.tri_mul_out),
"triangle_multiplication_incoming":
TriMulInParams(b.core.tri_mul_in),
TriMulInParams(b.pair_stack.tri_mul_in),
"triangle_attention_starting_node":
TriAttParams(b.core.tri_att_start),
TriAttParams(b.pair_stack.tri_att_start),
"triangle_attention_ending_node":
TriAttParams(b.core.tri_att_end),
TriAttParams(b.pair_stack.tri_att_end),
"pair_transition":
PairTransitionParams(b.core.pair_transition),
PairTransitionParams(b.pair_stack.pair_transition),
}
return d
ExtraMSABlockParams = partial(EvoformerBlockParams, is_extra_msa=True)
FoldIterationParams = lambda sm: {
"invariant_point_attention": IPAParams(sm.ipa),
def FoldIterationParams(sm):
d = {
"invariant_point_attention":
IPAParamsMultimer(sm.ipa) if is_multimer else IPAParams(sm.ipa),
"attention_layer_norm": LayerNormParams(sm.layer_norm_ipa),
"transition": LinearParams(sm.transition.layers[0].linear_1),
"transition_1": LinearParams(sm.transition.layers[0].linear_2),
......@@ -309,25 +399,39 @@ def generate_translation_dict(model, version):
"affine_update": LinearParams(sm.bb_update.linear),
"rigid_sidechain": {
"input_projection": LinearParams(sm.angle_resnet.linear_in),
"input_projection_1": LinearParams(sm.angle_resnet.linear_initial),
"input_projection_1":
LinearParams(sm.angle_resnet.linear_initial),
"resblock1": LinearParams(sm.angle_resnet.layers[0].linear_1),
"resblock2": LinearParams(sm.angle_resnet.layers[0].linear_2),
"resblock1_1": LinearParams(sm.angle_resnet.layers[1].linear_1),
"resblock2_1": LinearParams(sm.angle_resnet.layers[1].linear_2),
"unnormalized_angles": LinearParams(sm.angle_resnet.linear_out),
"resblock1_1":
LinearParams(sm.angle_resnet.layers[1].linear_1),
"resblock2_1":
LinearParams(sm.angle_resnet.layers[1].linear_2),
"unnormalized_angles":
LinearParams(sm.angle_resnet.linear_out),
},
}
if is_multimer:
d.pop("affine_update")
d["quat_rigid"] = {
"rigid": LinearParams(
sm.bb_update.linear
)
}
return d
############################
# translations dict overflow
############################
ems_blocks = model.extra_msa_stack.blocks
ems_blocks_params = stacked([ExtraMSABlockParams(b) for b in ems_blocks])
evo_blocks = model.evoformer.blocks
evo_blocks_params = stacked([EvoformerBlockParams(b) for b in evo_blocks])
if not is_multimer:
translations = {
"evoformer": {
"preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
......@@ -383,6 +487,64 @@ def generate_translation_dict(model, version):
"logits": LinearParams(model.aux_heads.masked_msa.linear),
},
}
else:
translations = {
"evoformer": {
"preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
"preprocess_msa": LinearParams(model.input_embedder.linear_msa_m),
"left_single": LinearParams(model.input_embedder.linear_tf_z_i),
"right_single": LinearParams(model.input_embedder.linear_tf_z_j),
"prev_pos_linear": LinearParams(model.recycling_embedder.linear),
"prev_msa_first_row_norm": LayerNormParams(
model.recycling_embedder.layer_norm_m
),
"prev_pair_norm": LayerNormParams(
model.recycling_embedder.layer_norm_z
),
"~_relative_encoding": {
"position_activations": LinearParams(
model.input_embedder.linear_relpos
),
},
"extra_msa_activations": LinearParams(
model.extra_msa_embedder.linear
),
"extra_msa_stack": ems_blocks_params,
"evoformer_iteration": evo_blocks_params,
"single_activations": LinearParams(model.evoformer.linear),
},
"structure_module": {
"single_layer_norm": LayerNormParams(
model.structure_module.layer_norm_s
),
"initial_projection": LinearParams(
model.structure_module.linear_in
),
"pair_layer_norm": LayerNormParams(
model.structure_module.layer_norm_z
),
"fold_iteration": FoldIterationParams(model.structure_module),
},
"predicted_lddt_head": {
"input_layer_norm": LayerNormParams(
model.aux_heads.plddt.layer_norm
),
"act_0": LinearParams(model.aux_heads.plddt.linear_1),
"act_1": LinearParams(model.aux_heads.plddt.linear_2),
"logits": LinearParams(model.aux_heads.plddt.linear_3),
},
"distogram_head": {
"half_logits": LinearParams(model.aux_heads.distogram.linear),
},
"experimentally_resolved_head": {
"logits": LinearParams(
model.aux_heads.experimentally_resolved.linear
),
},
"masked_msa_head": {
"logits": LinearParams(model.aux_heads.masked_msa.linear),
},
}
no_templ = [
"model_3",
......@@ -394,48 +556,98 @@ def generate_translation_dict(model, version):
]
if version not in no_templ:
tps_blocks = model.template_pair_stack.blocks
tps_blocks = model.template_embedder.template_pair_stack.blocks
tps_blocks_params = stacked(
[TemplatePairBlockParams(b) for b in tps_blocks]
)
if not is_multimer:
template_param_dict = {
"template_embedding": {
"single_template_embedding": {
"embedding2d": LinearParams(
model.template_pair_embedder.linear
model.template_embedder.template_pair_embedder.linear
),
"template_pair_stack": {
"__layer_stack_no_state": tps_blocks_params,
},
"output_layer_norm": LayerNormParams(
model.template_pair_stack.layer_norm
model.template_embedder.template_pair_stack.layer_norm
),
},
"attention": AttentionParams(model.template_pointwise_att.mha),
"attention": AttentionParams(model.template_embedder.template_pointwise_att.mha),
},
"template_single_embedding": LinearParams(
model.template_angle_embedder.linear_1
model.template_embedder.template_single_embedder.linear_1
),
"template_projection": LinearParams(
model.template_angle_embedder.linear_2
model.template_embedder.template_single_embedder.linear_2
),
}
else:
temp_embedder = model.template_embedder
template_param_dict = {
"template_embedding": {
"single_template_embedding": {
"query_embedding_norm": LayerNormParams(
temp_embedder.template_pair_embedder.query_embedding_layer_norm
),
"template_pair_embedding_0": LinearParams(
temp_embedder.template_pair_embedder.dgram_linear
),
"template_pair_embedding_1": LinearParams(
temp_embedder.template_pair_embedder.pseudo_beta_mask_linear
),
"template_pair_embedding_2": LinearParams(
temp_embedder.template_pair_embedder.aatype_linear_1
),
"template_pair_embedding_3": LinearParams(
temp_embedder.template_pair_embedder.aatype_linear_2
),
"template_pair_embedding_4": LinearParams(
temp_embedder.template_pair_embedder.x_linear
),
"template_pair_embedding_5": LinearParams(
temp_embedder.template_pair_embedder.y_linear
),
"template_pair_embedding_6": LinearParams(
temp_embedder.template_pair_embedder.z_linear
),
"template_pair_embedding_7": LinearParams(
temp_embedder.template_pair_embedder.backbone_mask_linear
),
"template_pair_embedding_8": LinearParams(
temp_embedder.template_pair_embedder.query_embedding_linear
),
"template_embedding_iteration": tps_blocks_params,
"output_layer_norm": LayerNormParams(
temp_embedder.template_pair_stack.layer_norm
),
},
"output_linear": LinearParams(
temp_embedder.linear_t
),
},
"template_projection": LinearParams(
temp_embedder.template_single_embedder.template_projector,
),
"template_single_embedding": LinearParams(
temp_embedder.template_single_embedder.template_single_embedder,
),
}
translations["evoformer"].update(template_param_dict)
if "_ptm" in version:
if is_multimer or "_ptm" in version:
translations["predicted_aligned_error_head"] = {
"logits": LinearParams(model.aux_heads.tm.linear)
}
return translations
def import_jax_weights_(model, npz_path, version="model_1"):
data = np.load(npz_path)
translations = generate_translation_dict(model, version)
translations = generate_translation_dict(model, version, is_multimer=("multimer" in version))
# Flatten keys and insert missing key prefixes
flat = process_translation_dict(translations)
......@@ -453,3 +665,45 @@ def import_jax_weights_(model, npz_path, version="model_1"):
# Set weights
assign(flat, data)
def convert_deprecated_v1_keys(state_dict):
"""Update older OpenFold model weight names to match the current model code."""
replacements = {
'template_angle_embedder': 'template_single_embedder',
'core.msa_transition': 'msa_transition',
'core.outer_product_mean': 'outer_product_mean',
'core.tri_': 'pair_stack.tri_',
'core.pair_transition': 'pair_stack.pair_transition',
'ipa.linear_q_points': 'ipa.linear_q_points.linear',
'ipa.linear_kv_points': 'ipa.linear_kv_points.linear'
}
convert_key_re = re.compile("(%s)" % "|".join(map(re.escape, replacements.keys())))
converted_state_dict = {}
for key, value in state_dict.items():
# For each match, look-up replacement value in the dictionary
new_key = convert_key_re.sub(lambda m: replacements[m.group()], key)
# Add prefix for template modules
if new_key.startswith('template'):
new_key = f'template_embedder.{new_key}'
converted_state_dict[new_key] = value
return converted_state_dict
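A before/after sketch of the v1 key translation (the key and value are illustrative): the deprecated module name is rewritten, and template keys additionally gain the new template_embedder prefix:

old = {"template_angle_embedder.linear_1.weight": torch.zeros(1)}
new = convert_deprecated_v1_keys(old)
# the key becomes "template_embedder.template_single_embedder.linear_1.weight"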
def import_openfold_weights_(model, state_dict):
"""
Import model weights. Several parts of the model were refactored in the process
of adding support for Multimer. The state dicts of older models are translated
to match the refactored model code.
"""
try:
model.load_state_dict(state_dict)
except RuntimeError:
converted_state_dict = convert_deprecated_v1_keys(state_dict)
model.load_state_dict(converted_state_dict)
......@@ -13,25 +13,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import logging
import ml_collections
import numpy as np
import torch
import torch.nn as nn
from torch.distributions.bernoulli import Bernoulli
from typing import Dict, Optional, Tuple
from openfold.np import residue_constants
from openfold.utils import feats
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.geometry.vector import Vec3Array, euclidean_distance
from openfold.utils.all_atom_multimer import get_rc_tensor
from openfold.utils.tensor_utils import (
tree_map,
tensor_tree_map,
masked_mean,
permute_final_dims,
batched_gather,
)
import logging
from openfold.utils.tensor_utils import tensor_tree_map
logger = logging.getLogger(__name__)
def softmax_cross_entropy(logits, labels):
......@@ -87,6 +87,7 @@ def compute_fape(
target_positions: torch.Tensor,
positions_mask: torch.Tensor,
length_scale: float,
pair_mask: Optional[torch.Tensor] = None,
l1_clamp_distance: Optional[float] = None,
eps=1e-8,
) -> torch.Tensor:
......@@ -108,6 +109,9 @@ def compute_fape(
[*, N_pts] positions mask
length_scale:
Length scale by which the loss is divided
pair_mask:
[*, N_frames, N_pts] mask to use for
separating intra- from inter-chain losses.
l1_clamp_distance:
Cutoff above which distance errors are disregarded
eps:
......@@ -134,6 +138,15 @@ def compute_fape(
normed_error = normed_error * frames_mask[..., None]
normed_error = normed_error * positions_mask[..., None, :]
if pair_mask is not None:
normed_error = normed_error * pair_mask
normed_error = torch.sum(normed_error, dim=(-1, -2))
mask = frames_mask[..., None] * positions_mask[..., None, :] * pair_mask
norm_factor = torch.sum(mask, dim=(-2, -1))
normed_error = normed_error / (eps + norm_factor)
else:
# FP16-friendly averaging. Roughly equivalent to:
#
# norm_factor = (
......@@ -157,13 +170,19 @@ def backbone_loss(
backbone_rigid_tensor: torch.Tensor,
backbone_rigid_mask: torch.Tensor,
traj: torch.Tensor,
pair_mask: Optional[torch.Tensor] = None,
use_clamped_fape: Optional[torch.Tensor] = None,
clamp_distance: float = 10.0,
loss_unit_distance: float = 10.0,
eps: float = 1e-4,
**kwargs,
) -> torch.Tensor:
# Check whether traj holds tensor_7 affines or [*, 4, 4] transformation matrices
if traj.shape[-1] == 7:
pred_aff = Rigid.from_tensor_7(traj)
elif traj.shape[-1] == 4:
pred_aff = Rigid.from_tensor_4x4(traj)
pred_aff = Rigid(
Rotation(rot_mats=pred_aff.get_rots().get_rot_mats(), quats=None),
pred_aff.get_trans(),
......@@ -184,6 +203,7 @@ def backbone_loss(
pred_aff.get_trans(),
gt_aff[None].get_trans(),
backbone_rigid_mask[None],
pair_mask=pair_mask,
l1_clamp_distance=clamp_distance,
length_scale=loss_unit_distance,
eps=eps,
......@@ -196,6 +216,7 @@ def backbone_loss(
pred_aff.get_trans(),
gt_aff[None].get_trans(),
backbone_rigid_mask[None],
pair_mask=pair_mask,
l1_clamp_distance=None,
length_scale=loss_unit_distance,
eps=eps,
......@@ -253,6 +274,7 @@ def sidechain_loss(
sidechain_atom_pos,
renamed_atom14_gt_positions,
renamed_atom14_gt_exists,
pair_mask=None,
l1_clamp_distance=clamp_distance,
length_scale=length_scale,
eps=eps,
......@@ -266,10 +288,28 @@ def fape_loss(
batch: Dict[str, torch.Tensor],
config: ml_collections.ConfigDict,
) -> torch.Tensor:
traj = out["sm"]["frames"]
asym_id = batch.get("asym_id")
if asym_id is not None:
intra_chain_mask = (asym_id[..., None] == asym_id[..., None, :]).to(dtype=traj.dtype)
intra_chain_bb_loss = backbone_loss(
traj=traj,
pair_mask=intra_chain_mask,
**{**batch, **config.intra_chain_backbone},
)
interface_bb_loss = backbone_loss(
traj=traj,
pair_mask=1. - intra_chain_mask,
**{**batch, **config.interface_backbone},
)
weighted_bb_loss = (intra_chain_bb_loss * config.intra_chain_backbone.weight
+ interface_bb_loss * config.interface_backbone.weight)
else:
bb_loss = backbone_loss(
traj=out["sm"]["frames"],
traj=traj,
**{**batch, **config.backbone},
)
weighted_bb_loss = bb_loss * config.backbone.weight
sc_loss = sidechain_loss(
out["sm"]["sidechain_frames"],
......@@ -277,7 +317,7 @@ def fape_loss(
**{**batch, **config.sidechain},
)
loss = config.backbone.weight * bb_loss + config.sidechain.weight * sc_loss
loss = weighted_bb_loss + config.sidechain.weight * sc_loss
# Average over the batch dimension
loss = torch.mean(loss)
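A sketch of the pair masks behind the new intra-chain/interface FAPE split (asym_id values illustrative):

asym_id = torch.tensor([0, 0, 1, 1])    # two chains of two residues each
intra_chain_mask = (asym_id[..., None] == asym_id[..., None, :]).float()
# intra_chain_mask is block-diagonal; backbone FAPE is evaluated once with this
# mask (intra-chain) and once with 1 - intra_chain_mask (interface), then the two
# terms are combined with their configured weights.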
......@@ -452,7 +492,7 @@ def lddt_ca(
ca_pos = residue_constants.atom_order["CA"]
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim
all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)] # keep dim
return lddt(
all_atom_pred_pos,
......@@ -482,7 +522,7 @@ def lddt_loss(
ca_pos = residue_constants.atom_order["CA"]
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim
all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)] # keep dim
score = lddt(
all_atom_pred_pos,
......@@ -492,8 +532,11 @@ def lddt_loss(
eps=eps
)
score = score.detach()
# TODO: Remove after initial pipeline testing
score = torch.nan_to_num(score, nan=torch.nanmean(score))
score[score < 0] = 0
score = score.detach()
bin_index = torch.floor(score * no_bins).long()
bin_index = torch.clamp(bin_index, max=(no_bins - 1))
lddt_ca_one_hot = torch.nn.functional.one_hot(
......@@ -627,6 +670,8 @@ def compute_predicted_aligned_error(
def compute_tm(
logits: torch.Tensor,
residue_weights: Optional[torch.Tensor] = None,
asym_id: Optional[torch.Tensor] = None,
interface: bool = False,
max_bin: int = 31,
no_bins: int = 64,
eps: float = 1e-8,
......@@ -649,7 +694,22 @@ def compute_tm(
tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2))
predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)
normed_residue_mask = residue_weights / (eps + residue_weights.sum())
n = residue_weights.shape[-1]
pair_mask = residue_weights.new_ones((n, n), dtype=torch.int32)
if interface and (asym_id is not None):
if len(asym_id.shape) > 1:
assert len(asym_id.shape) <= 2
batch_size = asym_id.shape[0]
pair_mask = residue_weights.new_ones((batch_size, n, n), dtype=torch.int32)
pair_mask *= (asym_id[..., None] != asym_id[..., None, :]).to(dtype=pair_mask.dtype)
predicted_tm_term *= pair_mask
pair_residue_weights = pair_mask * (
residue_weights[..., None, :] * residue_weights[..., :, None]
)
denom = eps + torch.sum(pair_residue_weights, dim=-1, keepdims=True)
normed_residue_mask = pair_residue_weights / denom
per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)
weighted = per_alignment * residue_weights
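The interface branch above restricts the pTM aggregation to cross-chain pairs; a toy version of the mask (asym_id illustrative):

asym_id = torch.tensor([0, 0, 1])
pair_mask = (asym_id[..., None] != asym_id[..., None, :]).to(torch.int32)
# pair_mask zeroes all intra-chain (i, j) pairs, and the residue-weight products
# are renormalised row-wise before the per-alignment sum.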
......@@ -671,7 +731,11 @@ def tm_loss(
eps=1e-8,
**kwargs,
):
# First check whether this is a tensor_7 or a tensor_4x4
if final_affine_tensor.shape[-1] == 7:
pred_affine = Rigid.from_tensor_7(final_affine_tensor)
elif final_affine_tensor.shape[-1] == 4:
pred_affine = Rigid.from_tensor_4x4(final_affine_tensor)
backbone_rigid = Rigid.from_tensor_4x4(backbone_rigid_tensor)
def _points(affine):
......@@ -709,7 +773,7 @@ def tm_loss(
(resolution >= min_resolution) & (resolution <= max_resolution)
)
# Average over the loss dimension
# Average over the batch dimension
loss = torch.mean(loss)
return loss
......@@ -784,6 +848,7 @@ def between_residue_bond_loss(
] + next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[
1
]
c_n_bond_length_error = torch.sqrt(eps + (c_n_bond_length - gt_length) ** 2)
c_n_loss_per_residue = torch.nn.functional.relu(
c_n_bond_length_error - tolerance_factor_soft * gt_stddev
......@@ -879,6 +944,7 @@ def between_residue_clash_loss(
atom14_atom_exists: torch.Tensor,
atom14_atom_radius: torch.Tensor,
residue_index: torch.Tensor,
asym_id: Optional[torch.Tensor] = None,
overlap_tolerance_soft=1.5,
overlap_tolerance_hard=1.5,
eps=1e-10,
......@@ -908,7 +974,6 @@ def between_residue_clash_loss(
shape (N, 14)
"""
fp_type = atom14_pred_positions.dtype
# Create the distance matrix.
# (N, N, 14, 14)
dists = torch.sqrt(
......@@ -954,9 +1019,13 @@ def between_residue_clash_loss(
)
n_one_hot = n_one_hot.type(fp_type)
neighbour_mask = (
residue_index[..., :, None, None, None] + 1
) == residue_index[..., None, :, None, None]
neighbour_mask = (residue_index[..., :, None] + 1) == residue_index[..., None, :]
if asym_id is not None:
neighbour_mask = neighbour_mask & (asym_id[..., :, None] == asym_id[..., None, :])
neighbour_mask = neighbour_mask[..., None, None]
c_n_bonds = (
neighbour_mask
* c_one_hot[..., None, None, :, None]
......@@ -998,7 +1067,7 @@ def between_residue_clash_loss(
# Compute the per atom loss sum.
# shape (N, 14)
per_atom_loss_sum = torch.sum(dists_to_low_error, dim=(-4, -2)) + torch.sum(
dists_to_low_error, axis=(-3, -1)
dists_to_low_error, dim=(-3, -1)
)
# Compute the hard clash mask.
......@@ -1007,17 +1076,20 @@ def between_residue_clash_loss(
dists < (dists_lower_bound - overlap_tolerance_hard)
)
per_atom_num_clash = torch.sum(clash_mask, dim=(-4, -2)) + torch.sum(clash_mask, dim=(-3, -1))
# Compute the per atom clash.
# shape (N, 14)
per_atom_clash_mask = torch.maximum(
torch.amax(clash_mask, axis=(-4, -2)),
torch.amax(clash_mask, axis=(-3, -1)),
torch.amax(clash_mask, dim=(-4, -2)),
torch.amax(clash_mask, dim=(-3, -1)),
)
return {
"mean_loss": mean_loss, # shape ()
"per_atom_loss_sum": per_atom_loss_sum, # shape (N, 14)
"per_atom_clash_mask": per_atom_clash_mask, # shape (N, 14)
"per_atom_num_clash": per_atom_num_clash # shape (N, 14)
}
......@@ -1097,6 +1169,8 @@ def within_residue_violations(
(dists < atom14_dists_lower_bound) | (dists > atom14_dists_upper_bound)
)
per_atom_num_clash = torch.sum(violations, dim=-2) + torch.sum(violations, dim=-1)
# Compute the per atom violations.
per_atom_violations = torch.maximum(
torch.max(violations, dim=-2)[0], torch.max(violations, axis=-1)[0]
......@@ -1105,6 +1179,7 @@ def within_residue_violations(
return {
"per_atom_loss_sum": per_atom_loss_sum,
"per_atom_violations": per_atom_violations,
"per_atom_num_clash": per_atom_num_clash
}
......@@ -1134,7 +1209,20 @@ def find_structural_violations(
residue_constants.van_der_waals_radius[name[0]]
for name in residue_constants.atom_types
]
atomtype_radius = atom14_pred_positions.new_tensor(atomtype_radius)
# TODO: Consolidate monomer/multimer modes
asym_id = batch.get("asym_id")
if asym_id is not None:
residx_atom14_to_atom37 = get_rc_tensor(
residue_constants.RESTYPE_ATOM14_TO_ATOM37, batch["aatype"]
)
atom14_atom_radius = (
batch["atom14_atom_exists"]
* atomtype_radius[residx_atom14_to_atom37.long()]
)
else:
atom14_atom_radius = (
batch["atom14_atom_exists"]
* atomtype_radius[batch["residx_atom14_to_atom37"]]
......@@ -1146,6 +1234,7 @@ def find_structural_violations(
atom14_atom_exists=batch["atom14_atom_exists"],
atom14_atom_radius=atom14_atom_radius,
residue_index=batch["residue_index"],
asym_id=asym_id,
overlap_tolerance_soft=clash_overlap_tolerance,
overlap_tolerance_hard=clash_overlap_tolerance,
)
......@@ -1208,6 +1297,9 @@ def find_structural_violations(
"clashes_per_atom_clash_mask": between_residue_clashes[
"per_atom_clash_mask"
], # (N, 14)
"clashes_per_atom_num_clash": between_residue_clashes[
"per_atom_num_clash"
], # (N, 14)
},
"within_residues": {
"per_atom_loss_sum": residue_violations[
......@@ -1216,6 +1308,9 @@ def find_structural_violations(
"per_atom_violations": residue_violations[
"per_atom_violations"
], # (N, 14),
"per_atom_num_clash": residue_violations[
"per_atom_num_clash"
], # (N, 14)
},
"total_per_residue_violations_mask": per_residue_violations_mask, # (N)
}
......@@ -1337,15 +1432,21 @@ def compute_violation_metrics_np(
def violation_loss(
violations: Dict[str, torch.Tensor],
atom14_atom_exists: torch.Tensor,
average_clashes: bool = False,
eps=1e-6,
**kwargs,
) -> torch.Tensor:
num_atoms = torch.sum(atom14_atom_exists)
l_clash = torch.sum(
violations["between_residues"]["clashes_per_atom_loss_sum"]
+ violations["within_residues"]["per_atom_loss_sum"]
)
l_clash = l_clash / (eps + num_atoms)
per_atom_clash = (violations["between_residues"]["clashes_per_atom_loss_sum"] +
violations["within_residues"]["per_atom_loss_sum"])
if average_clashes:
num_clash = (violations["between_residues"]["clashes_per_atom_num_clash"] +
violations["within_residues"]["per_atom_num_clash"])
per_atom_clash = per_atom_clash / (num_clash + eps)
l_clash = torch.sum(per_atom_clash) / (eps + num_atoms)
loss = (
violations["between_residues"]["bonds_c_n_loss_mean"]
+ violations["between_residues"]["angles_ca_c_n_loss_mean"]
......@@ -1491,7 +1592,7 @@ def experimentally_resolved_loss(
return loss
def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
def masked_msa_loss(logits, true_msa, bert_mask, num_classes, eps=1e-8, **kwargs):
"""
Computes BERT-style masked MSA loss. Implements subsection 1.9.9.
......@@ -1503,7 +1604,7 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
Masked MSA loss
"""
errors = softmax_cross_entropy(
logits, torch.nn.functional.one_hot(true_msa, num_classes=23)
logits, torch.nn.functional.one_hot(true_msa, num_classes=num_classes)
)
# FP16-friendly averaging. Equivalent to:
......@@ -1524,13 +1625,75 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
return loss
def chain_center_of_mass_loss(
all_atom_pred_pos: torch.Tensor,
all_atom_positions: torch.Tensor,
all_atom_mask: torch.Tensor,
asym_id: torch.Tensor,
clamp_distance: float = -4.0,
weight: float = 0.05,
eps: float = 1e-10, **kwargs
) -> torch.Tensor:
"""
Computes chain centre-of-mass loss. Implements section 2.5, eqn 1 in the Multimer paper.
Args:
all_atom_pred_pos:
[*, N_pts, 37, 3] All-atom predicted atom positions
all_atom_positions:
[*, N_pts, 37, 3] Ground truth all-atom positions
all_atom_mask:
[*, N_pts, 37] All-atom positions mask
asym_id:
[*, N_pts] Chain asym IDs
clamp_distance:
Cutoff above which distance errors are disregarded
weight:
Weight for loss
eps:
Small value used to regularize denominators
Returns:
[*] loss tensor
"""
ca_pos = residue_constants.atom_order["CA"]
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)] # keep dim
one_hot = torch.nn.functional.one_hot(asym_id.long()).to(dtype=all_atom_mask.dtype)
one_hot = one_hot * all_atom_mask
chain_pos_mask = one_hot.transpose(-2, -1)
chain_exists = torch.any(chain_pos_mask, dim=-1).to(dtype=all_atom_positions.dtype)
def get_chain_center_of_mass(pos):
center_sum = (chain_pos_mask[..., None] * pos[..., None, :, :]).sum(dim=-2)
centers = center_sum / (torch.sum(chain_pos_mask, dim=-1, keepdim=True) + eps)
return Vec3Array.from_array(centers)
pred_centers = get_chain_center_of_mass(all_atom_pred_pos) # [B, NC, 3]
true_centers = get_chain_center_of_mass(all_atom_positions) # [B, NC, 3]
pred_dists = euclidean_distance(pred_centers[..., None, :], pred_centers[..., :, None], epsilon=eps)
true_dists = euclidean_distance(true_centers[..., None, :], true_centers[..., :, None], epsilon=eps)
losses = torch.clamp((weight * (pred_dists - true_dists - clamp_distance)), max=0) ** 2
loss_mask = chain_exists[..., :, None] * chain_exists[..., None, :]
loss = masked_mean(loss_mask, losses, dim=(-1, -2))
return loss
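For orientation, here is a minimal smoke test of the loss above (a sketch: the tensors are random, and the openfold.utils.loss import path is an assumption about where this diff lives):
import torch
from openfold.utils.loss import chain_center_of_mass_loss  # assumed module path
asym_id = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 2]])  # two chains, 4 residues each
loss = chain_center_of_mass_loss(
    all_atom_pred_pos=torch.randn(1, 8, 37, 3),
    all_atom_positions=torch.randn(1, 8, 37, 3),
    all_atom_mask=torch.ones(1, 8, 37),
    asym_id=asym_id,
)
print(loss.shape)  # torch.Size([1]): one scalar per batch element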
class AlphaFoldLoss(nn.Module):
"""Aggregation of the various losses described in the supplement"""
def __init__(self, config):
super(AlphaFoldLoss, self).__init__()
self.config = config
def forward(self, out, batch, _return_breakdown=False):
def loss(self, out, batch, _return_breakdown=False):
"""
The previous forward() is renamed loss()
so that it can be reused in subclasses.
"""
if "violation" not in out.keys():
out["violation"] = find_structural_violations(
batch,
......@@ -1576,31 +1739,36 @@ class AlphaFoldLoss(nn.Module):
),
"violation": lambda: violation_loss(
out["violation"],
**batch,
**{**batch, **self.config.violation},
),
}
if(self.config.tm.enabled):
if self.config.tm.enabled:
loss_fns["tm"] = lambda: tm_loss(
logits=out["tm_logits"],
**{**batch, **out, **self.config.tm},
)
if self.config.chain_center_of_mass.enabled:
loss_fns["chain_center_of_mass"] = lambda: chain_center_of_mass_loss(
all_atom_pred_pos=out["final_atom_positions"],
**{**batch, **self.config.chain_center_of_mass},
)
cum_loss = 0.
losses = {}
for loss_name, loss_fn in loss_fns.items():
weight = self.config[loss_name].weight
loss = loss_fn()
if(torch.isnan(loss) or torch.isinf(loss)):
#for k,v in batch.items():
# if(torch.any(torch.isnan(v)) or torch.any(torch.isinf(v))):
if torch.isnan(loss) or torch.isinf(loss):
# for k,v in batch.items():
# if torch.any(torch.isnan(v)) or torch.any(torch.isinf(v)):
# logging.warning(f"{k}: is nan")
#logging.warning(f"{loss_name}: {loss}")
# logging.warning(f"{loss_name}: {loss}")
logging.warning(f"{loss_name} loss is NaN. Skipping...")
loss = loss.new_tensor(0., requires_grad=True)
cum_loss = cum_loss + weight * loss
losses[loss_name] = loss.detach().clone()
losses["unscaled_loss"] = cum_loss.detach().clone()
# Scale the loss by the square root of the minimum of the crop size and
......@@ -1611,7 +1779,15 @@ class AlphaFoldLoss(nn.Module):
losses["loss"] = cum_loss.detach().clone()
if(not _return_breakdown):
if not _return_breakdown:
return cum_loss
return cum_loss, losses
def forward(self, out, batch, _return_breakdown=False):
if not _return_breakdown:
cum_loss = self.loss(out, batch, _return_breakdown)
return cum_loss
else:
cum_loss, losses = self.loss(out, batch, _return_breakdown)
return cum_loss, losses
import logging
import random
import torch
from openfold.np import residue_constants as rc
logger = logging.getLogger(__name__)
def compute_rmsd(
true_atom_pos: torch.Tensor,
pred_atom_pos: torch.Tensor,
atom_mask: torch.Tensor = None,
eps: float = 1e-6,
) -> torch.Tensor:
sq_diff = torch.square(true_atom_pos - pred_atom_pos).sum(dim=-1, keepdim=False)
if atom_mask is not None:
sq_diff = torch.masked_select(sq_diff, atom_mask.to(sq_diff.device))
msd = torch.mean(sq_diff)
msd = torch.nan_to_num(msd, nan=1e8)
return torch.sqrt(msd + eps) # prevent sqrt 0
def kabsch_rotation(P, Q):
"""
Calculate the best rotation that minimises the RMSD between P and Q.
The optimal rotation matrix is computed with the Kabsch algorithm:
https://en.wikipedia.org/wiki/Kabsch_algorithm
Args:
P: [N, 3] where N is the number of atoms and each row holds an atom's x, y, z coordinates
Q: [N, 3] the same shape as P
Returns:
A 3x3 rotation matrix
"""
assert P.shape == torch.Size([Q.shape[0], Q.shape[1]])
# Firstly, compute SVD of P.T * Q
u, _, vt = torch.linalg.svd(torch.matmul(P.to(torch.float32).T,
Q.to(torch.float32)))
# Then construct s matrix
s = torch.eye(P.shape[1], device=P.device)
# correct the rotation matrix to ensure a right-handed coordinate
s[-1, -1] = torch.sign(torch.linalg.det(torch.matmul(u, vt)))
# finally compute the rotation matrix
r_opt = torch.matmul(torch.matmul(u, s), vt)
assert r_opt.shape == torch.Size([3,3])
return r_opt.to(device=P.device, dtype=P.dtype)
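A quick sanity check of the convention (a sketch): kabsch_rotation returns r such that P @ r best matches Q, so a known rotation should be recovered for centered point sets.
import math
import torch
c, s = math.cos(0.3), math.sin(0.3)
true_rot = torch.tensor([[c, -s, 0.0],
                         [s,  c, 0.0],
                         [0.0, 0.0, 1.0]])
P = torch.randn(10, 3)
P = P - P.mean(dim=0, keepdim=True)  # Kabsch assumes centered point sets
Q = P @ true_rot
assert torch.allclose(kabsch_rotation(P, Q), true_rot, atol=1e-4)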
def get_optimal_transform(
src_atoms: torch.Tensor,
tgt_atoms: torch.Tensor,
mask: torch.Tensor = None,
):
"""
src_atoms: predicted CA positions, shape:[num_res,3]
tgt_atoms: ground-truth CA positions, shape:[num_res,3]
mask: a vector of boolean values, shape:[num_res]
"""
assert src_atoms.shape == tgt_atoms.shape, (src_atoms.shape, tgt_atoms.shape)
assert src_atoms.shape[-1] == 3
if mask is not None:
assert len(mask.shape) == 1, "mask should have the shape of [num_res]"
if torch.isnan(src_atoms).any() or torch.isinf(src_atoms).any():
# Fake test inputs can sometimes produce NaNs in the predicted atom positions
logging.warning("src_atoms has NaN or Inf values")
src_atoms = torch.nan_to_num(src_atoms, nan=0.0, posinf=1.0, neginf=1.0)
if mask is not None:
assert mask.dtype == torch.bool
assert mask.shape[-1] == src_atoms.shape[-2]
if mask.sum() == 0:
src_atoms = torch.zeros((1, 3), device=src_atoms.device, dtype=src_atoms.dtype)
tgt_atoms = src_atoms
else:
src_atoms = src_atoms[mask, :]
tgt_atoms = tgt_atoms[mask, :]
src_center = src_atoms.mean(-2, keepdim=True, dtype=src_atoms.dtype)
tgt_center = tgt_atoms.mean(-2, keepdim=True, dtype=src_atoms.dtype)
# Center both point sets before the Kabsch step
r = kabsch_rotation(src_atoms - src_center, tgt_atoms - tgt_center)
x = tgt_center - src_center @ r
return r, x
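Usage follows the same src @ r + x convention (a sketch with synthetic points; all names are illustrative):
import math
import torch
c, s = math.cos(-1.1), math.sin(-1.1)
rot = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
src = torch.randn(20, 3)
tgt = src @ rot + torch.tensor([1.0, -2.0, 0.5])  # rotate, then translate
r, x = get_optimal_transform(src, tgt, mask=torch.ones(20, dtype=torch.bool))
assert torch.allclose(src @ r + x, tgt, atol=1e-4)  # superposed onto tgt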
def get_least_asym_entity_or_longest_length(batch, input_asym_id):
"""
First check how many subunits each sequence has and select a subunit of the least
common entity, e.g. if the protein is AABBB then select one of the A chains as the anchor.
If there is a tie, e.g. AABB, check which sequence is the longer/longest,
then choose one of its subunits as the anchor.
Args:
batch: in this function batch is the full ground truth features
input_asym_id: A list of asym_ids that are in the cropped input features
Return:
anchor_gt_asym_id: Tensor(int) selected ground truth asym_id
anchor_pred_asym_ids: list(Tensor(int)) a list of all possible pred anchor candidates
"""
entity_2_asym_list = get_entity_2_asym_list(batch)
unique_entity_ids = torch.unique(batch["entity_id"])
entity_asym_count = {}
entity_length = {}
for entity_id in unique_entity_ids:
asym_ids = torch.unique(batch["asym_id"][batch["entity_id"] == entity_id])
# Make sure some asym IDs associated with ground truth entity ID exist in cropped prediction
asym_ids_in_pred = [a for a in asym_ids if a in input_asym_id]
if not asym_ids_in_pred:
continue
entity_asym_count[int(entity_id)] = len(asym_ids)
# Calculate entity length
entity_mask = (batch["entity_id"] == entity_id)
entity_length[int(entity_id)] = entity_mask.sum().item()
min_asym_count = min(entity_asym_count.values())
least_asym_entities = [entity for entity, count in entity_asym_count.items() if count == min_asym_count]
# If multiple entities have the least asym_id count, return those with the longest length
if len(least_asym_entities) > 1:
max_length = max([entity_length[entity] for entity in least_asym_entities])
least_asym_entities = [entity for entity in least_asym_entities if entity_length[entity] == max_length]
# If still multiple entities, return a random one
if len(least_asym_entities) > 1:
least_asym_entities = [random.choice(least_asym_entities)]
assert len(least_asym_entities) == 1
least_asym_entities = least_asym_entities[0]
anchor_gt_asym_id = random.choice(entity_2_asym_list[least_asym_entities])
anchor_pred_asym_ids = [asym_id for asym_id in entity_2_asym_list[least_asym_entities] if asym_id in input_asym_id]
return anchor_gt_asym_id, anchor_pred_asym_ids
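A toy illustration of the selection rule (a sketch; two residues per chain, A2B3 stoichiometry, all tensors hypothetical):
import torch
batch = {
    "entity_id": torch.tensor([1, 1, 1, 1, 2, 2, 2, 2, 2, 2]),
    "asym_id":   torch.tensor([1, 1, 2, 2, 3, 3, 4, 4, 5, 5]),
}
anchor_gt, anchor_pred = get_least_asym_entity_or_longest_length(
    batch, input_asym_id=[1, 2, 3, 4, 5]
)
# Entity 1 has two copies, entity 2 has three, so the anchor comes from
# entity 1: int(anchor_gt) is 1 or 2, and anchor_pred contains asym IDs 1 and 2.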
def greedy_align(
batch,
per_asym_residue_index,
entity_2_asym_list,
pred_ca_pos,
pred_ca_mask,
true_ca_poses,
true_ca_masks,
):
"""
Implements Algorithm 4 in the Supplementary Information of the AlphaFold-Multimer paper:
Evans, R. et al., 2022, Protein complex prediction with AlphaFold-Multimer, bioRxiv 2021.10.04.463034; doi: https://doi.org/10.1101/2021.10.04.463034
"""
used = [False for _ in range(len(true_ca_poses))]
align = []
unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i != 0]
for cur_asym_id in unique_asym_ids:
i = int(cur_asym_id - 1)
asym_mask = batch["asym_id"] == cur_asym_id
cur_entity_ids = batch["entity_id"][asym_mask][0]
best_rmsd = torch.inf
best_idx = None
cur_asym_list = entity_2_asym_list[int(cur_entity_ids)]
cur_residue_index = per_asym_residue_index[int(cur_asym_id)]
cur_pred_pos = pred_ca_pos[asym_mask]
cur_pred_mask = pred_ca_mask[asym_mask]
for next_asym_id in cur_asym_list:
j = int(next_asym_id - 1)
if not used[j]: # possible candidate
cropped_pos = torch.index_select(true_ca_poses[j], 1, cur_residue_index)
mask = torch.index_select(true_ca_masks[j], 1, cur_residue_index)
rmsd = compute_rmsd(
torch.squeeze(cropped_pos, 0), torch.squeeze(cur_pred_pos, 0),
(cur_pred_mask * mask).bool()
)
if rmsd is not None and rmsd < best_rmsd:
best_rmsd = rmsd
best_idx = j
assert best_idx is not None
used[best_idx] = True
align.append((i, best_idx))
return align
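The output is easiest to read as 0-based index pairs (a sketch):
# align is a list of (pred_chain_idx, gt_chain_idx) pairs, e.g. for a
# homodimer whose two copies are swapped in the prediction:
#   align == [(0, 1), (1, 0)]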
def pad_features(feature_tensor, nres_pad, pad_dim):
"""Pad input feature tensor"""
pad_shape = list(feature_tensor.shape)
pad_shape[pad_dim] = nres_pad
padding_tensor = feature_tensor.new_zeros(pad_shape, device=feature_tensor.device)
return torch.concat((feature_tensor, padding_tensor), dim=pad_dim)
def merge_labels(per_asym_residue_index, labels, align, original_nres):
"""
Merge ground truth labels according to the permutation results
labels: list of original ground truth feats
align: list of tuples; each entry specifies which ground-truth label is assigned to an asym ID
Modified from UniFold:
https://github.com/dptech-corp/Uni-Fold/blob/b1c89a2cebd4e4ee4c47b4e443f92beeb9138fbb/unifold/losses/chain_align.py#L176C1-L176C1
"""
outs = {}
for k, v in labels[0].items():
cur_out = {}
for i, j in align:
label = labels[j][k]
# to 1-based
cur_residue_index = per_asym_residue_index[i + 1]
if len(v.shape) <= 1 or "template" in k or "row_mask" in k:
continue
else:
dimension_to_merge = 1
cur_out[i] = label.index_select(dimension_to_merge, cur_residue_index)
cur_out = [x[1] for x in sorted(cur_out.items())]
if len(cur_out) > 0:
new_v = torch.concat(cur_out, dim=dimension_to_merge)
# below check whether padding is needed
if new_v.shape[dimension_to_merge] != original_nres:
nres_pad = original_nres - new_v.shape[dimension_to_merge]
new_v = pad_features(new_v, nres_pad, pad_dim=dimension_to_merge)
outs[k] = new_v
return outs
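As a sketch of the contract (hypothetical homodimer, three residues per chain, leading batch dimension of 1):
import torch
labels = [
    {"all_atom_positions": torch.randn(1, 3, 37, 3)},
    {"all_atom_positions": torch.randn(1, 3, 37, 3)},
]
per_asym_residue_index = {1: torch.arange(3), 2: torch.arange(3)}
merged = merge_labels(per_asym_residue_index, labels,
                      align=[(0, 1), (1, 0)], original_nres=6)
# merged["all_atom_positions"] has shape (1, 6, 37, 3), with the two
# ground-truth chains swapped before concatenation along the residue axis.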
def split_ground_truth_labels(gt_features):
"""
Splits ground truth features according to chains
Returns:
a list of feature dictionaries containing only the ground truth features
required to finish the multi-chain permutation
"""
unique_asym_ids, asym_id_counts = torch.unique(gt_features["asym_id"], sorted=True, return_counts=True)
n_res = gt_features["asym_id"].shape[-1]
def split_dim(shape):
return next(iter(i for i, size in enumerate(shape) if size == n_res), None)
labels = list(map(dict, zip(*[[(k, v) for v in torch.split(v_all, asym_id_counts.tolist(),
dim=split_dim(v_all.shape))]
for k, v_all in gt_features.items()
if n_res in v_all.shape])))
return labels
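A toy split (a sketch; real label dicts carry more keys and a leading batch dimension):
import torch
gt = {
    "asym_id": torch.tensor([1, 1, 1, 2, 2]),
    "all_atom_positions": torch.randn(5, 37, 3),
}
chains = split_ground_truth_labels(gt)
assert chains[0]["all_atom_positions"].shape == (3, 37, 3)
assert chains[1]["all_atom_positions"].shape == (2, 37, 3)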
def get_per_asym_residue_index(features):
unique_asym_ids = [i for i in torch.unique(features["asym_id"]) if i != 0]
per_asym_residue_index = {}
for cur_asym_id in unique_asym_ids:
asym_mask = (features["asym_id"] == cur_asym_id).bool()
per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(features["residue_index"], asym_mask)
return per_asym_residue_index
def get_entity_2_asym_list(batch):
"""
Generates a dictionary mapping unique entity IDs to lists of unique asymmetry IDs (asym_id) for each entity.
Args:
batch (dict): A dictionary containing data batches, including "entity_id" and "asym_id" tensors.
Returns:
entity_2_asym_list (dict): A dictionary where keys are unique entity IDs, and values are lists of unique asymmetry IDs
associated with each entity.
"""
entity_2_asym_list = {}
unique_entity_ids = torch.unique(batch["entity_id"])
for cur_ent_id in unique_entity_ids:
ent_mask = batch["entity_id"] == cur_ent_id
cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
return entity_2_asym_list
def calculate_input_mask(true_ca_masks, anchor_gt_idx, anchor_gt_residue,
asym_mask, pred_ca_mask):
"""
Calculate an input mask for downstream optimal transformation computation
Args:
true_ca_masks (Tensor): CA masks from the ground truth.
anchor_gt_idx (Tensor): The index of the selected ground-truth anchor.
asym_mask (Tensor): Boolean tensor indicating which residues belong to the selected predicted anchor.
pred_ca_mask (Tensor): CA mask from the predicted structure.
Returns:
input_mask (Tensor): A boolean mask
"""
pred_ca_mask = torch.squeeze(pred_ca_mask, 0)
asym_mask = torch.squeeze(asym_mask, 0)
anchor_pred_mask = pred_ca_mask[asym_mask]
anchor_true_mask = torch.index_select(true_ca_masks[anchor_gt_idx], 1, anchor_gt_residue)
input_mask = (anchor_true_mask * anchor_pred_mask).bool()
return input_mask
def calculate_optimal_transform(true_ca_poses,
anchor_gt_idx, anchor_gt_residue,
true_ca_masks, pred_ca_mask,
asym_mask,
pred_ca_pos):
input_mask = calculate_input_mask(true_ca_masks,
anchor_gt_idx,
anchor_gt_residue,
asym_mask,
pred_ca_mask)
input_mask = torch.squeeze(input_mask, 0)
pred_ca_pos = torch.squeeze(pred_ca_pos, 0)
asym_mask = torch.squeeze(asym_mask, 0)
anchor_true_pos = torch.index_select(true_ca_poses[anchor_gt_idx], 1, anchor_gt_residue)
anchor_pred_pos = pred_ca_pos[asym_mask]
r, x = get_optimal_transform(
anchor_pred_pos, torch.squeeze(anchor_true_pos, 0),
mask=input_mask
)
return r, x
def compute_permutation_alignment(out, features, ground_truth):
"""
Permutes chains in the ground truth before the loss is calculated.
Details are described in Section 7.3 of the Supplementary Information of the AlphaFold-Multimer paper:
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
"""
unique_asym_ids = set(torch.unique(features['asym_id']).tolist())
unique_asym_ids.discard(0) # Remove padding asym_id
is_monomer = len(unique_asym_ids) == 1
per_asym_residue_index = get_per_asym_residue_index(features)
if is_monomer:
best_align = list(enumerate(range(len(per_asym_residue_index))))
return best_align, per_asym_residue_index
best_rmsd = float('inf')
best_align = None
# First select anchors from predicted structures and ground truths
anchor_gt_asym, anchor_pred_asym_ids = get_least_asym_entity_or_longest_length(ground_truth,
features['asym_id'])
entity_2_asym_list = get_entity_2_asym_list(ground_truth)
labels = split_ground_truth_labels(ground_truth)
assert isinstance(labels, list)
anchor_gt_idx = int(anchor_gt_asym) - 1
# Then calculate optimal transform by aligning anchors
ca_idx = rc.atom_order["CA"]
pred_ca_pos = out["final_atom_positions"][..., ca_idx, :] # [bsz, nres, 3]
pred_ca_mask = out["final_atom_mask"][..., ca_idx].to(dtype=pred_ca_pos.dtype) # [bsz, nres]
true_ca_poses = [
l["all_atom_positions"][..., ca_idx, :] for l in labels
] # list([nres, 3])
true_ca_masks = [
l["all_atom_mask"][..., ca_idx].long() for l in labels
] # list([nres,])
for candidate_pred_anchor in anchor_pred_asym_ids:
asym_mask = (features["asym_id"] == candidate_pred_anchor).bool()
anchor_gt_residue = per_asym_residue_index[candidate_pred_anchor.item()]
r, x = calculate_optimal_transform(true_ca_poses,
anchor_gt_idx,
anchor_gt_residue,
true_ca_masks,
pred_ca_mask,
asym_mask,
pred_ca_pos)
aligned_true_ca_poses = [ca.to(r.dtype) @ r + x for ca in true_ca_poses] # apply transforms
align = greedy_align(
features,
per_asym_residue_index,
entity_2_asym_list,
pred_ca_pos,
pred_ca_mask,
aligned_true_ca_poses,
true_ca_masks,
)
merged_labels = merge_labels(per_asym_residue_index, labels, align,
original_nres=features['aatype'].shape[-1])
rmsd = compute_rmsd(
true_atom_pos=merged_labels['all_atom_positions'][..., ca_idx, :].to(r.dtype) @ r + x,
pred_atom_pos=pred_ca_pos,
atom_mask=(pred_ca_mask * merged_labels['all_atom_mask'][..., ca_idx].long()).bool())
if rmsd < best_rmsd:
best_rmsd = rmsd
best_align = align
return best_align, per_asym_residue_index
def multi_chain_permutation_align(out, features, ground_truth):
"""Compute multi-chain permutation alignment.
Args:
out: The output of model.forward()
features: Input features
ground_truth: Ground truth features
"""
labels = split_ground_truth_labels(ground_truth)
# Then permute ground truth chains before calculating the loss
align, per_asym_residue_index = compute_permutation_alignment(out=out,
features=features,
ground_truth=ground_truth)
# reorder ground truth labels according to permutation results
labels = merge_labels(per_asym_residue_index, labels, align,
original_nres=features['aatype'].shape[-1])
features.update(labels)
return features
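Putting the pieces together, the intended call site looks roughly like this (a sketch; model, features, ground_truth, and config come from the surrounding training loop):
# out = model(features)
# features = multi_chain_permutation_align(out=out,
#                                          features=features,
#                                          ground_truth=ground_truth)
# loss, breakdown = AlphaFoldLoss(config)(out, features, _return_breakdown=True)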
......@@ -978,6 +978,16 @@ class Rigid:
"""
return self._trans.device
@property
def dtype(self) -> torch.dtype:
"""
Returns the dtype of the Rigid tensors.
Returns:
The dtype of the Rigid tensors
"""
return self._rots.dtype
def get_rots(self) -> Rotation:
"""
Getter for the rotation.
......
......@@ -12,6 +12,7 @@ from openfold.np import residue_constants, protein
from openfold.np.relax import relax
from openfold.utils.import_weights import (
import_jax_weights_,
import_openfold_weights_
)
from pytorch_lightning.utilities.deepspeed import (
......@@ -90,7 +91,7 @@ def load_models_from_command_line(config, model_device, openfold_checkpoint_path
ckpt_path,
)
d = torch.load(ckpt_path)
model.load_state_dict(d["ema"]["params"])
import_openfold_weights_(model=model, state_dict=d["ema"]["params"])
else:
ckpt_path = path
d = torch.load(ckpt_path)
......@@ -98,7 +99,7 @@ def load_models_from_command_line(config, model_device, openfold_checkpoint_path
if "ema" in d:
# The public weights have had this done to them already
d = d["ema"]["params"]
model.load_state_dict(d)
import_openfold_weights_(model=model, state_dict=d)
model = model.to(model_device)
logger.info(
......@@ -122,7 +123,7 @@ def parse_fasta(data):
][1:]
tags, seqs = lines[::2], lines[1::2]
tags = [t.split()[0] for t in tags]
tags = [re.split('\W| \|', t)[0] for t in tags]
return tags, seqs
......@@ -219,7 +220,7 @@ def prep_output(out, batch, feature_dict, feature_processor, config_preset, mult
features=batch,
result=out,
b_factors=plddt_b_factors,
chain_index=chain_index,
remove_leading_feature_dimension=False,
remark=remark,
parents=template_domain_names,
parents_chain_index=template_chain_index,
......
......@@ -17,24 +17,19 @@ import logging
import math
import numpy as np
import os
from openfold.utils.script_utils import load_models_from_command_line, parse_fasta, run_model, prep_output, \
update_timings, relax_protein
import pickle
import random
import time
logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)
import pickle
import random
import time
import torch
torch_versions = torch.__version__.split(".")
torch_major_version = int(torch_versions[0])
torch_minor_version = int(torch_versions[1])
if(
if (
torch_major_version > 1 or
(torch_major_version == 1 and torch_minor_version >= 12)
):
......@@ -45,16 +40,16 @@ torch.set_grad_enabled(False)
from openfold.config import model_config
from openfold.data import templates, feature_pipeline, data_pipeline
from openfold.np import residue_constants, protein
import openfold.np.relax.relax as relax
from openfold.utils.tensor_utils import (
tensor_tree_map,
)
from openfold.data.tools import hhsearch, hmmsearch
from openfold.np import protein
from openfold.utils.script_utils import (load_models_from_command_line, parse_fasta, run_model,
prep_output, relax_protein)
from openfold.utils.tensor_utils import tensor_tree_map
from openfold.utils.trace_utils import (
pad_feature_dict_seq,
trace_model_,
)
from scripts.precompute_embeddings import EmbeddingGenerator
from scripts.utils import add_data_args
......@@ -69,18 +64,30 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
fp.write(f">{tag}\n{seq}")
local_alignment_dir = os.path.join(alignment_dir, tag)
if(args.use_precomputed_alignments is None and not os.path.isdir(local_alignment_dir)):
if args.use_precomputed_alignments is None:
logger.info(f"Generating alignments for {tag}...")
os.makedirs(local_alignment_dir)
os.makedirs(local_alignment_dir, exist_ok=True)
if "multimer" in args.config_preset:
template_searcher = hmmsearch.Hmmsearch(
binary_path=args.hmmsearch_binary_path,
hmmbuild_binary_path=args.hmmbuild_binary_path,
database_path=args.pdb_seqres_database_path,
)
else:
template_searcher = hhsearch.HHSearch(
binary_path=args.hhsearch_binary_path,
databases=[args.pdb70_database_path],
)
# In seqemb mode, use AlignmentRunner only to generate templates
if args.use_single_seq_mode:
alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhsearch_binary_path=args.hhsearch_binary_path,
uniref90_database_path=args.uniref90_database_path,
pdb70_database_path=args.pdb70_database_path,
template_searcher=template_searcher,
no_cpus=args.cpus,
)
embedding_generator = EmbeddingGenerator()
......@@ -89,14 +96,17 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hhsearch_binary_path=args.hhsearch_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniref30_database_path=args.uniref30_database_path,
uniclust30_database_path=args.uniclust30_database_path,
pdb70_database_path=args.pdb70_database_path,
no_cpus=args.cpus,
uniprot_database_path=args.uniprot_database_path,
template_searcher=template_searcher,
use_small_bfd=args.bfd_database_path is None,
no_cpus=args.cpus
)
alignment_runner.run(
tmp_fasta_path, local_alignment_dir
)
......@@ -133,6 +143,14 @@ def generate_feature_dict(
alignment_dir=local_alignment_dir,
seqemb_mode=args.use_single_seq_mode,
)
elif "multimer" in args.config_preset:
with open(tmp_fasta_path, "w") as fp:
fp.write(
'\n'.join([f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)])
)
feature_dict = data_processor.process_fasta(
fasta_path=tmp_fasta_path, alignment_dir=alignment_dir,
)
else:
with open(tmp_fasta_path, "w") as fp:
fp.write(
......@@ -147,6 +165,7 @@ def generate_feature_dict(
return feature_dict
def list_files_with_extensions(dir, extensions):
return [f for f in os.listdir(dir) if f.endswith(extensions)]
......@@ -157,15 +176,28 @@ def main(args):
if args.config_preset.startswith("seq"):
args.use_single_seq_mode = True
config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference)
if(args.trace_model):
if(not config.data.predict.fixed_size):
if args.trace_model:
if not config.data.predict.fixed_size:
raise ValueError(
"Tracing requires that fixed_size mode be enabled in the config"
)
template_featurizer = templates.TemplateHitFeaturizer(
is_multimer = "multimer" in args.config_preset
if is_multimer:
template_featurizer = templates.HmmsearchHitFeaturizer(
mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date,
max_hits=config.data.predict.max_templates,
kalign_binary_path=args.kalign_binary_path,
release_dates_path=args.release_dates_path,
obsolete_pdbs_path=args.obsolete_pdbs_path
)
else:
template_featurizer = templates.HhsearchHitFeaturizer(
mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date,
max_hits=config.data.predict.max_templates,
......@@ -178,10 +210,15 @@ def main(args):
template_featurizer=template_featurizer,
)
if is_multimer:
data_processor = data_pipeline.DataPipelineMultimer(
monomer_data_pipeline=data_processor,
)
output_dir_base = args.output_dir
random_seed = args.data_random_seed
if random_seed is None:
random_seed = random.randrange(2**32)
random_seed = random.randrange(2 ** 32)
np.random.seed(random_seed)
torch.manual_seed(random_seed + 1)
......@@ -198,10 +235,19 @@ def main(args):
seq_list = []
for fasta_file in list_files_with_extensions(args.fasta_dir, (".fasta", ".fa")):
# Gather input sequences
with open(os.path.join(args.fasta_dir, fasta_file), "r") as fp:
fasta_path = os.path.join(args.fasta_dir, fasta_file)
with open(fasta_path, "r") as fp:
data = fp.read()
tags, seqs = parse_fasta(data)
if not is_multimer and len(tags) != 1:
print(
f"{fasta_path} contains more than one sequence but "
f"multimer mode is not enabled. Skipping..."
)
continue
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tag = '-'.join(tags)
......@@ -217,6 +263,7 @@ def main(args):
args.openfold_checkpoint_path,
args.jax_param_path,
args.output_dir)
for model, output_directory in model_generator:
cur_tracing_interval = 0
for (tag, tags), seqs in sorted_targets:
......@@ -228,7 +275,7 @@ def main(args):
precompute_alignments(tags, seqs, alignment_dir, args)
feature_dict = feature_dicts.get(tag, None)
if(feature_dict is None):
if feature_dict is None:
feature_dict = generate_feature_dict(
tags,
seqs,
......@@ -237,7 +284,7 @@ def main(args):
args,
)
if(args.trace_model):
if args.trace_model:
n = feature_dict["aatype"].shape[-2]
rounded_seqlen = round_up_seqlen(n)
feature_dict = pad_feature_dict_seq(
......@@ -247,16 +294,16 @@ def main(args):
feature_dicts[tag] = feature_dict
processed_feature_dict = feature_processor.process_features(
feature_dict, mode='predict',
feature_dict, mode='predict', is_multimer=is_multimer
)
processed_feature_dict = {
k:torch.as_tensor(v, device=args.model_device)
for k,v in processed_feature_dict.items()
k: torch.as_tensor(v, device=args.model_device)
for k, v in processed_feature_dict.items()
}
if(args.trace_model):
if(rounded_seqlen > cur_tracing_interval):
if args.trace_model:
if rounded_seqlen > cur_tracing_interval:
logger.info(
f"Tracing model at {rounded_seqlen} residues..."
)
......@@ -305,7 +352,8 @@ def main(args):
if not args.skip_relaxation:
# Relax the prediction.
logger.info(f"Running relaxation on {unrelaxed_output_path}...")
relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name, args.cif_output)
relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name,
args.cif_output)
if args.save_outputs:
output_dict_path = os.path.join(
......@@ -407,13 +455,13 @@ if __name__ == "__main__":
add_data_args(parser)
args = parser.parse_args()
if(args.jax_param_path is None and args.openfold_checkpoint_path is None):
if args.jax_param_path is None and args.openfold_checkpoint_path is None:
args.jax_param_path = os.path.join(
"openfold", "resources", "params",
"params_" + args.config_preset + ".npz"
)
if(args.model_device == "cpu" and torch.cuda.is_available()):
if args.model_device == "cpu" and torch.cuda.is_available():
logging.warning(
"""The model is being run on CPU. Consider specifying
--model_device for better performance"""
......
......@@ -26,6 +26,7 @@ from openfold.utils.import_weights import (
ParamType,
generate_translation_dict,
process_translation_dict,
import_openfold_weights_
)
from openfold.utils.tensor_utils import tree_map
......@@ -63,7 +64,7 @@ def main(args):
config = model_config(args.config_preset)
model = AlphaFold(config)
model.load_state_dict(d)
import_openfold_weights_(model=model, state_dict=d)
translation = generate_translation_dict(model, args.config_preset)
translation = process_translation_dict(translation)
......
import argparse
import logging
import os
import string
from collections import defaultdict
from openfold.data import mmcif_parsing
from openfold.np import protein, residue_constants
......@@ -22,7 +23,7 @@ def main(args):
if(mmcif.mmcif_object is None):
logging.warning(f'Failed to parse {fname}...')
if(args.raise_errors):
raise list(mmcif.errors.values())[0]
raise Exception(list(mmcif.errors.values())[0])
else:
continue
......@@ -31,6 +32,25 @@ def main(args):
chain_id = '_'.join([basename, chain])
fasta.append(f">{chain_id}")
fasta.append(seq)
elif(ext == ".pdb"):
with open(fpath, 'r') as fp:
pdb_str = fp.read()
protein_object = protein.from_pdb_string(pdb_str)
aatype = protein_object.aatype
chain_index = protein_object.chain_index
last_chain_index = chain_index[0]
chain_dict = defaultdict(list)
for i in range(aatype.shape[0]):
chain_dict[chain_index[i]].append(residue_constants.restypes_with_x[aatype[i]])
chain_tags = string.ascii_uppercase
for chain, seq in chain_dict.items():
chain_id = '_'.join([basename, chain_tags[chain]])
fasta.append(f">{chain_id}")
fasta.append(''.join(seq))
elif(ext == ".core"):
with open(fpath, 'r') as fp:
core_str = fp.read()
......
import copy
import os
import torch
import deepspeed
local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.ml = torch.nn.ModuleList()
for _ in range(4000):
self.ml.append(torch.nn.Linear(500, 500))
def forward(self, batch):
for i, l in enumerate(self.ml):
# print(f"{i}: {l.weight.device}")
batch = l(batch)
return batch
class DummyDataset(torch.utils.data.Dataset):
def __init__(self):
self.batch = torch.rand(500, 500)
def __getitem__(self, idx):
return copy.deepcopy(self.batch)
def __len__(self):
return 1000
dd = DummyDataset()
dl = torch.utils.data.DataLoader(dd)
example = next(iter(dl)).to(f"cuda:{local_rank}")
model = Model()
model = model.to(f"cuda:{local_rank}")
model = deepspeed.init_inference(
model,
mp_size=world_size,
checkpoint=None,
replace_method=None,
#replace_method="auto"
)
out = model(example)
#if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
# print(out)
......@@ -56,10 +56,16 @@ bash "${SCRIPT_DIR}/download_pdb70.sh" "${DOWNLOAD_DIR}"
echo "Downloading PDB mmCIF files..."
bash "${SCRIPT_DIR}/download_pdb_mmcif.sh" "${DOWNLOAD_DIR}"
echo "Downloading Uniclust30..."
bash "${SCRIPT_DIR}/download_uniclust30.sh" "${DOWNLOAD_DIR}"
echo "Downloading Uniref30..."
bash "${SCRIPT_DIR}/download_uniref30.sh" "${DOWNLOAD_DIR}"
echo "Downloading Uniref90..."
bash "${SCRIPT_DIR}/download_uniref90.sh" "${DOWNLOAD_DIR}"
echo "Downloading UniProt..."
bash "${SCRIPT_DIR}/download_uniprot.sh" "${DOWNLOAD_DIR}"
echo "Downloading PDB SeqRes..."
bash "${SCRIPT_DIR}/download_pdb_seqres.sh" "${DOWNLOAD_DIR}"
echo "All data downloaded."