Commit 56d5e39c authored by Geoffrey Yu

Merge remote-tracking branch 'upstream/multimer' into multimer

parents 56b86074 51556d52
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vec3Array Class."""
from __future__ import annotations
import dataclasses
from typing import Union, List
import torch
from openfold.utils.geometry import utils
Float = Union[float, torch.Tensor]
@dataclasses.dataclass(frozen=True)
class Vec3Array:
x: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32})
y: torch.Tensor
z: torch.Tensor
def __post_init__(self):
if hasattr(self.x, 'dtype'):
assert self.x.dtype == self.y.dtype
assert self.x.dtype == self.z.dtype
assert all([x == y for x, y in zip(self.x.shape, self.y.shape)])
assert all([x == z for x, z in zip(self.x.shape, self.z.shape)])
def __add__(self, other: Vec3Array) -> Vec3Array:
return Vec3Array(
self.x + other.x,
self.y + other.y,
self.z + other.z,
)
def __sub__(self, other: Vec3Array) -> Vec3Array:
return Vec3Array(
self.x - other.x,
self.y - other.y,
self.z - other.z,
)
def __mul__(self, other: Float) -> Vec3Array:
return Vec3Array(
self.x * other,
self.y * other,
self.z * other,
)
def __rmul__(self, other: Float) -> Vec3Array:
return self * other
def __truediv__(self, other: Float) -> Vec3Array:
return Vec3Array(
self.x / other,
self.y / other,
self.z / other,
)
def __neg__(self) -> Vec3Array:
return self * -1
def __pos__(self) -> Vec3Array:
return self * 1
def __getitem__(self, index) -> Vec3Array:
return Vec3Array(
self.x[index],
self.y[index],
self.z[index],
)
def __iter__(self):
return iter((self.x, self.y, self.z))
@property
def shape(self):
return self.x.shape
def map_tensor_fn(self, fn) -> Vec3Array:
return Vec3Array(
fn(self.x),
fn(self.y),
fn(self.z),
)
def cross(self, other: Vec3Array) -> Vec3Array:
"""Compute cross product between 'self' and 'other'."""
new_x = self.y * other.z - self.z * other.y
new_y = self.z * other.x - self.x * other.z
new_z = self.x * other.y - self.y * other.x
return Vec3Array(new_x, new_y, new_z)
def dot(self, other: Vec3Array) -> Float:
"""Compute dot product between 'self' and 'other'."""
return self.x * other.x + self.y * other.y + self.z * other.z
def norm(self, epsilon: float = 1e-6) -> Float:
"""Compute Norm of Vec3Array, clipped to epsilon."""
# To avoid NaN on the backward pass, we must use maximum before the sqrt
norm2 = self.dot(self)
if epsilon:
norm2 = torch.clamp(norm2, min=epsilon**2)
return torch.sqrt(norm2)
def norm2(self):
return self.dot(self)
def normalized(self, epsilon: float = 1e-6) -> Vec3Array:
"""Return unit vector with optional clipping."""
return self / self.norm(epsilon)
def clone(self) -> Vec3Array:
return Vec3Array(
self.x.clone(),
self.y.clone(),
self.z.clone(),
)
def reshape(self, new_shape) -> Vec3Array:
x = self.x.reshape(new_shape)
y = self.y.reshape(new_shape)
z = self.z.reshape(new_shape)
return Vec3Array(x, y, z)
def sum(self, dim: int) -> Vec3Array:
return Vec3Array(
torch.sum(self.x, dim=dim),
torch.sum(self.y, dim=dim),
torch.sum(self.z, dim=dim),
)
def unsqueeze(self, dim: int):
return Vec3Array(
self.x.unsqueeze(dim),
self.y.unsqueeze(dim),
self.z.unsqueeze(dim),
)
@classmethod
def zeros(cls, shape, device="cpu"):
"""Return Vec3Array corresponding to zeros of given shape."""
return cls(
torch.zeros(shape, dtype=torch.float32, device=device),
torch.zeros(shape, dtype=torch.float32, device=device),
torch.zeros(shape, dtype=torch.float32, device=device)
)
def to_tensor(self) -> torch.Tensor:
return torch.stack([self.x, self.y, self.z], dim=-1)
@classmethod
def from_array(cls, tensor):
return cls(*torch.unbind(tensor, dim=-1))
@classmethod
def cat(cls, vecs: List[Vec3Array], dim: int) -> Vec3Array:
return cls(
torch.cat([v.x for v in vecs], dim=dim),
torch.cat([v.y for v in vecs], dim=dim),
torch.cat([v.z for v in vecs], dim=dim),
)
def square_euclidean_distance(
vec1: Vec3Array,
vec2: Vec3Array,
epsilon: float = 1e-6
) -> Float:
"""Computes square of euclidean distance between 'vec1' and 'vec2'.
Args:
vec1: Vec3Array to compute distance to
vec2: Vec3Array to compute distance from, should be
broadcast compatible with 'vec1'
epsilon: distance is clipped from below to be at least epsilon
Returns:
Array of square euclidean distances;
shape will be result of broadcasting 'vec1' and 'vec2'
"""
difference = vec1 - vec2
distance = difference.dot(difference)
if epsilon:
distance = torch.clamp(distance, min=epsilon)
return distance
def dot(vector1: Vec3Array, vector2: Vec3Array) -> Float:
return vector1.dot(vector2)
def cross(vector1: Vec3Array, vector2: Vec3Array) -> Vec3Array:
return vector1.cross(vector2)
def norm(vector: Vec3Array, epsilon: float = 1e-6) -> Float:
return vector.norm(epsilon)
def normalized(vector: Vec3Array, epsilon: float = 1e-6) -> Vec3Array:
return vector.normalized(epsilon)
def euclidean_distance(
vec1: Vec3Array,
vec2: Vec3Array,
epsilon: float = 1e-6
) -> Float:
"""Computes euclidean distance between 'vec1' and 'vec2'.
Args:
vec1: Vec3Array to compute euclidean distance to
vec2: Vec3Array to compute euclidean distance from, should be
broadcast compatible with 'vec1'
epsilon: distance is clipped from below to be at least epsilon
Returns:
Array of euclidean distances;
shape will be result of broadcasting 'vec1' and 'vec2'
"""
distance_sq = square_euclidean_distance(vec1, vec2, epsilon**2)
distance = torch.sqrt(distance_sq)
return distance
def dihedral_angle(a: Vec3Array, b: Vec3Array, c: Vec3Array,
d: Vec3Array) -> Float:
"""Computes torsion angle for a quadruple of points.
For points (a, b, c, d), this is the angle between the planes defined by
points (a, b, c) and (b, c, d). It is also known as the dihedral angle.
Arguments:
a: A Vec3Array of coordinates.
b: A Vec3Array of coordinates.
c: A Vec3Array of coordinates.
d: A Vec3Array of coordinates.
Returns:
A tensor of angles in radians: [-pi, pi].
"""
v1 = a - b
v2 = b - c
v3 = d - c
c1 = v1.cross(v2)
c2 = v3.cross(v2)
c3 = c2.cross(c1)
v2_mag = v2.norm()
return torch.atan2(c3.dot(v2), v2_mag * c1.dot(c2))
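A minimal usage sketch of the class above (coordinates are illustrative; Vec3Array and dihedral_angle are the names defined in this file):

import torch
from openfold.utils.geometry.vector import Vec3Array, dihedral_angle

coords = torch.randn(10, 4, 3)            # ten quadruples of 3D points
a, b, c, d = (Vec3Array.from_array(coords[:, i, :]) for i in range(4))
angles = dihedral_angle(a, b, c, d)
print(angles.shape)                       # torch.Size([10]), values in [-pi, pi]
print((a - b).norm().shape)               # torch.Size([10])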
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import re
 from enum import Enum
 from dataclasses import dataclass
 from functools import partial
@@ -39,6 +40,13 @@ class ParamType(Enum):
     LinearWeightOPM = partial(
         lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
     )
+    LinearWeightMultimer = partial(
+        lambda w: w.unsqueeze(-1) if len(w.shape) == 1 else
+        w.reshape(w.shape[0], -1).transpose(-1, -2)
+    )
+    LinearBiasMultimer = partial(
+        lambda w: w.reshape(-1)
+    )
     Other = partial(lambda w: w)

     def __init__(self, fn):
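A quick sketch (not part of the diff) of what the two new multimer transforms do to parameter shapes, assuming a fused JAX weight of shape [c_in, no_heads, c_hidden]; the dimensions are illustrative:

import torch

w = torch.randn(128, 4, 32)                         # [c_in, no_heads, c_hidden]
w_pt = w.reshape(w.shape[0], -1).transpose(-1, -2)  # LinearWeightMultimer path
print(w_pt.shape)                                   # torch.Size([128, 128]), i.e. [no_heads * c_hidden, c_in]

b = torch.randn(4, 32)                              # per-head bias
print(b.reshape(-1).shape)                          # torch.Size([128]), LinearBiasMultimer path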
@@ -122,26 +130,32 @@ def assign(translation_dict, orig_weights):
             raise

-def generate_translation_dict(model, version):
+def generate_translation_dict(model, version, is_multimer=False):
     #######################
     # Some templates
     #######################
     LinearWeight = lambda l: (Param(l, param_type=ParamType.LinearWeight))
     LinearBias = lambda l: (Param(l))
     LinearWeightMHA = lambda l: (Param(l, param_type=ParamType.LinearWeightMHA))
     LinearBiasMHA = lambda b: (Param(b, param_type=ParamType.LinearBiasMHA))
     LinearWeightOPM = lambda l: (Param(l, param_type=ParamType.LinearWeightOPM))
+    LinearWeightMultimer = lambda l: (
+        Param(l, param_type=ParamType.LinearWeightMultimer)
+    )
+    LinearBiasMultimer = lambda l: (
+        Param(l, param_type=ParamType.LinearBiasMultimer)
+    )

     LinearParams = lambda l: {
         "weights": LinearWeight(l.weight),
         "bias": LinearBias(l.bias),
     }
+    LinearParamsMultimer = lambda l: {
+        "weights": LinearWeightMultimer(l.weight),
+        "bias": LinearBiasMultimer(l.bias),
+    }

     LayerNormParams = lambda l: {
         "scale": Param(l.weight),
         "offset": Param(l.bias),
@@ -178,31 +192,47 @@ def generate_translation_dict(model, version):
         "attention": AttentionGatedParams(tri_att.mha),
     }

-    TriMulOutParams = lambda tri_mul: {
-        "layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
-        "left_projection": LinearParams(tri_mul.linear_a_p),
-        "right_projection": LinearParams(tri_mul.linear_b_p),
-        "left_gate": LinearParams(tri_mul.linear_a_g),
-        "right_gate": LinearParams(tri_mul.linear_b_g),
-        "center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
-        "output_projection": LinearParams(tri_mul.linear_z),
-        "gating_linear": LinearParams(tri_mul.linear_g),
-    }
-
-    # see commit b88f8da on the Alphafold repo
-    # Alphafold swaps the pseudocode's a and b between the incoming/outcoming
-    # iterations of triangle multiplication, which is confusing and not
-    # reproduced in our implementation.
-    TriMulInParams = lambda tri_mul: {
-        "layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
-        "left_projection": LinearParams(tri_mul.linear_b_p),
-        "right_projection": LinearParams(tri_mul.linear_a_p),
-        "left_gate": LinearParams(tri_mul.linear_b_g),
-        "right_gate": LinearParams(tri_mul.linear_a_g),
-        "center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
-        "output_projection": LinearParams(tri_mul.linear_z),
-        "gating_linear": LinearParams(tri_mul.linear_g),
-    }
+    def TriMulOutParams(tri_mul, outgoing=True):
+        if re.fullmatch("^model_[1-5]_multimer_v3$", version):
+            d = {
+                "left_norm_input": LayerNormParams(tri_mul.layer_norm_in),
+                "projection": LinearParams(tri_mul.linear_ab_p),
+                "gate": LinearParams(tri_mul.linear_ab_g),
+                "center_norm": LayerNormParams(tri_mul.layer_norm_out),
+            }
+        else:
+            # see commit b88f8da on the Alphafold repo
+            # Alphafold swaps the pseudocode's a and b between the incoming/outcoming
+            # iterations of triangle multiplication, which is confusing and not
+            # reproduced in our implementation.
+            if outgoing:
+                left_projection = LinearParams(tri_mul.linear_a_p)
+                right_projection = LinearParams(tri_mul.linear_b_p)
+                left_gate = LinearParams(tri_mul.linear_a_g)
+                right_gate = LinearParams(tri_mul.linear_b_g)
+            else:
+                left_projection = LinearParams(tri_mul.linear_b_p)
+                right_projection = LinearParams(tri_mul.linear_a_p)
+                left_gate = LinearParams(tri_mul.linear_b_g)
+                right_gate = LinearParams(tri_mul.linear_a_g)
+            d = {
+                "layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
+                "left_projection": left_projection,
+                "right_projection": right_projection,
+                "left_gate": left_gate,
+                "right_gate": right_gate,
+                "center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
+            }
+
+        d.update({
+            "output_projection": LinearParams(tri_mul.linear_z),
+            "gating_linear": LinearParams(tri_mul.linear_g),
+        })
+
+        return d
+
+    TriMulInParams = partial(TriMulOutParams, outgoing=False)

     PairTransitionParams = lambda pt: {
         "input_layer_norm": LayerNormParams(pt.layer_norm),
@@ -236,8 +266,46 @@ def generate_translation_dict(model, version):
     IPAParams = lambda ipa: {
         "q_scalar": LinearParams(ipa.linear_q),
         "kv_scalar": LinearParams(ipa.linear_kv),
-        "q_point_local": LinearParams(ipa.linear_q_points),
-        "kv_point_local": LinearParams(ipa.linear_kv_points),
+        "q_point_local": LinearParams(ipa.linear_q_points.linear),
+        "kv_point_local": LinearParams(ipa.linear_kv_points.linear),
+        "trainable_point_weights": Param(
+            param=ipa.head_weights, param_type=ParamType.Other
+        ),
+        "attention_2d": LinearParams(ipa.linear_b),
+        "output_projection": LinearParams(ipa.linear_out),
+    }
+
+    PointProjectionParams = lambda pp: {
+        "point_projection": LinearParamsMultimer(
+            pp.linear,
+        ),
+    }
+
+    IPAParamsMultimer = lambda ipa: {
+        "q_scalar_projection": {
+            "weights": LinearWeightMultimer(
+                ipa.linear_q.weight,
+            ),
+        },
+        "k_scalar_projection": {
+            "weights": LinearWeightMultimer(
+                ipa.linear_k.weight,
+            ),
+        },
+        "v_scalar_projection": {
+            "weights": LinearWeightMultimer(
+                ipa.linear_v.weight,
+            ),
+        },
+        "q_point_projection": PointProjectionParams(
+            ipa.linear_q_points
+        ),
+        "k_point_projection": PointProjectionParams(
+            ipa.linear_k_points
+        ),
+        "v_point_projection": PointProjectionParams(
+            ipa.linear_v_points
+        ),
         "trainable_point_weights": Param(
             param=ipa.head_weights, param_type=ParamType.Other
         ),
@@ -280,109 +348,183 @@ def generate_translation_dict(model, version):
                 b.msa_att_row
             ),
             col_att_name: msa_col_att_params,
-            "msa_transition": MSATransitionParams(b.core.msa_transition),
+            "msa_transition": MSATransitionParams(b.msa_transition),
             "outer_product_mean":
-                OuterProductMeanParams(b.core.outer_product_mean),
+                OuterProductMeanParams(b.outer_product_mean),
             "triangle_multiplication_outgoing":
-                TriMulOutParams(b.core.tri_mul_out),
+                TriMulOutParams(b.pair_stack.tri_mul_out),
             "triangle_multiplication_incoming":
-                TriMulInParams(b.core.tri_mul_in),
+                TriMulInParams(b.pair_stack.tri_mul_in),
             "triangle_attention_starting_node":
-                TriAttParams(b.core.tri_att_start),
+                TriAttParams(b.pair_stack.tri_att_start),
             "triangle_attention_ending_node":
-                TriAttParams(b.core.tri_att_end),
+                TriAttParams(b.pair_stack.tri_att_end),
             "pair_transition":
-                PairTransitionParams(b.core.pair_transition),
+                PairTransitionParams(b.pair_stack.pair_transition),
         }
         return d

     ExtraMSABlockParams = partial(EvoformerBlockParams, is_extra_msa=True)

-    FoldIterationParams = lambda sm: {
-        "invariant_point_attention": IPAParams(sm.ipa),
-        "attention_layer_norm": LayerNormParams(sm.layer_norm_ipa),
-        "transition": LinearParams(sm.transition.layers[0].linear_1),
-        "transition_1": LinearParams(sm.transition.layers[0].linear_2),
-        "transition_2": LinearParams(sm.transition.layers[0].linear_3),
-        "transition_layer_norm": LayerNormParams(sm.transition.layer_norm),
-        "affine_update": LinearParams(sm.bb_update.linear),
-        "rigid_sidechain": {
-            "input_projection": LinearParams(sm.angle_resnet.linear_in),
-            "input_projection_1": LinearParams(sm.angle_resnet.linear_initial),
-            "resblock1": LinearParams(sm.angle_resnet.layers[0].linear_1),
-            "resblock2": LinearParams(sm.angle_resnet.layers[0].linear_2),
-            "resblock1_1": LinearParams(sm.angle_resnet.layers[1].linear_1),
-            "resblock2_1": LinearParams(sm.angle_resnet.layers[1].linear_2),
-            "unnormalized_angles": LinearParams(sm.angle_resnet.linear_out),
-        },
-    }
+    def FoldIterationParams(sm):
+        d = {
+            "invariant_point_attention":
+                IPAParamsMultimer(sm.ipa) if is_multimer else IPAParams(sm.ipa),
+            "attention_layer_norm": LayerNormParams(sm.layer_norm_ipa),
+            "transition": LinearParams(sm.transition.layers[0].linear_1),
+            "transition_1": LinearParams(sm.transition.layers[0].linear_2),
+            "transition_2": LinearParams(sm.transition.layers[0].linear_3),
+            "transition_layer_norm": LayerNormParams(sm.transition.layer_norm),
+            "affine_update": LinearParams(sm.bb_update.linear),
+            "rigid_sidechain": {
+                "input_projection": LinearParams(sm.angle_resnet.linear_in),
+                "input_projection_1":
+                    LinearParams(sm.angle_resnet.linear_initial),
+                "resblock1": LinearParams(sm.angle_resnet.layers[0].linear_1),
+                "resblock2": LinearParams(sm.angle_resnet.layers[0].linear_2),
+                "resblock1_1":
+                    LinearParams(sm.angle_resnet.layers[1].linear_1),
+                "resblock2_1":
+                    LinearParams(sm.angle_resnet.layers[1].linear_2),
+                "unnormalized_angles":
+                    LinearParams(sm.angle_resnet.linear_out),
+            },
+        }
+
+        if(is_multimer):
+            d.pop("affine_update")
+            d["quat_rigid"] = {
+                "rigid": LinearParams(
+                    sm.bb_update.linear
+                )
+            }
+
+        return d
     ############################
     # translations dict overflow
     ############################

     ems_blocks = model.extra_msa_stack.blocks
     ems_blocks_params = stacked([ExtraMSABlockParams(b) for b in ems_blocks])

     evo_blocks = model.evoformer.blocks
     evo_blocks_params = stacked([EvoformerBlockParams(b) for b in evo_blocks])

-    translations = {
-        "evoformer": {
-            "preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
-            "preprocess_msa": LinearParams(model.input_embedder.linear_msa_m),
-            "left_single": LinearParams(model.input_embedder.linear_tf_z_i),
-            "right_single": LinearParams(model.input_embedder.linear_tf_z_j),
-            "prev_pos_linear": LinearParams(model.recycling_embedder.linear),
-            "prev_msa_first_row_norm": LayerNormParams(
-                model.recycling_embedder.layer_norm_m
-            ),
-            "prev_pair_norm": LayerNormParams(
-                model.recycling_embedder.layer_norm_z
-            ),
-            "pair_activiations": LinearParams(
-                model.input_embedder.linear_relpos
-            ),
-            "extra_msa_activations": LinearParams(
-                model.extra_msa_embedder.linear
-            ),
-            "extra_msa_stack": ems_blocks_params,
-            "evoformer_iteration": evo_blocks_params,
-            "single_activations": LinearParams(model.evoformer.linear),
-        },
-        "structure_module": {
-            "single_layer_norm": LayerNormParams(
-                model.structure_module.layer_norm_s
-            ),
-            "initial_projection": LinearParams(
-                model.structure_module.linear_in
-            ),
-            "pair_layer_norm": LayerNormParams(
-                model.structure_module.layer_norm_z
-            ),
-            "fold_iteration": FoldIterationParams(model.structure_module),
-        },
-        "predicted_lddt_head": {
-            "input_layer_norm": LayerNormParams(
-                model.aux_heads.plddt.layer_norm
-            ),
-            "act_0": LinearParams(model.aux_heads.plddt.linear_1),
-            "act_1": LinearParams(model.aux_heads.plddt.linear_2),
-            "logits": LinearParams(model.aux_heads.plddt.linear_3),
-        },
-        "distogram_head": {
-            "half_logits": LinearParams(model.aux_heads.distogram.linear),
-        },
-        "experimentally_resolved_head": {
-            "logits": LinearParams(
-                model.aux_heads.experimentally_resolved.linear
-            ),
-        },
-        "masked_msa_head": {
-            "logits": LinearParams(model.aux_heads.masked_msa.linear),
-        },
-    }
+    if(not is_multimer):
+        translations = {
+            "evoformer": {
+                "preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
+                "preprocess_msa": LinearParams(model.input_embedder.linear_msa_m),
+                "left_single": LinearParams(model.input_embedder.linear_tf_z_i),
+                "right_single": LinearParams(model.input_embedder.linear_tf_z_j),
+                "prev_pos_linear": LinearParams(model.recycling_embedder.linear),
+                "prev_msa_first_row_norm": LayerNormParams(
+                    model.recycling_embedder.layer_norm_m
+                ),
+                "prev_pair_norm": LayerNormParams(
+                    model.recycling_embedder.layer_norm_z
+                ),
+                "pair_activiations": LinearParams(
+                    model.input_embedder.linear_relpos
+                ),
+                "extra_msa_activations": LinearParams(
+                    model.extra_msa_embedder.linear
+                ),
+                "extra_msa_stack": ems_blocks_params,
+                "evoformer_iteration": evo_blocks_params,
+                "single_activations": LinearParams(model.evoformer.linear),
+            },
+            "structure_module": {
+                "single_layer_norm": LayerNormParams(
+                    model.structure_module.layer_norm_s
+                ),
+                "initial_projection": LinearParams(
+                    model.structure_module.linear_in
+                ),
+                "pair_layer_norm": LayerNormParams(
+                    model.structure_module.layer_norm_z
+                ),
+                "fold_iteration": FoldIterationParams(model.structure_module),
+            },
+            "predicted_lddt_head": {
+                "input_layer_norm": LayerNormParams(
+                    model.aux_heads.plddt.layer_norm
+                ),
+                "act_0": LinearParams(model.aux_heads.plddt.linear_1),
+                "act_1": LinearParams(model.aux_heads.plddt.linear_2),
+                "logits": LinearParams(model.aux_heads.plddt.linear_3),
+            },
+            "distogram_head": {
+                "half_logits": LinearParams(model.aux_heads.distogram.linear),
+            },
+            "experimentally_resolved_head": {
+                "logits": LinearParams(
+                    model.aux_heads.experimentally_resolved.linear
+                ),
+            },
+            "masked_msa_head": {
+                "logits": LinearParams(model.aux_heads.masked_msa.linear),
+            },
+        }
+    else:
+        translations = {
+            "evoformer": {
+                "preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
+                "preprocess_msa": LinearParams(model.input_embedder.linear_msa_m),
+                "left_single": LinearParams(model.input_embedder.linear_tf_z_i),
+                "right_single": LinearParams(model.input_embedder.linear_tf_z_j),
+                "prev_pos_linear": LinearParams(model.recycling_embedder.linear),
+                "prev_msa_first_row_norm": LayerNormParams(
+                    model.recycling_embedder.layer_norm_m
+                ),
+                "prev_pair_norm": LayerNormParams(
+                    model.recycling_embedder.layer_norm_z
+                ),
+                "~_relative_encoding": {
+                    "position_activations": LinearParams(
+                        model.input_embedder.linear_relpos
+                    ),
+                },
+                "extra_msa_activations": LinearParams(
+                    model.extra_msa_embedder.linear
+                ),
+                "extra_msa_stack": ems_blocks_params,
+                "evoformer_iteration": evo_blocks_params,
+                "single_activations": LinearParams(model.evoformer.linear),
+            },
+            "structure_module": {
+                "single_layer_norm": LayerNormParams(
+                    model.structure_module.layer_norm_s
+                ),
+                "initial_projection": LinearParams(
+                    model.structure_module.linear_in
+                ),
+                "pair_layer_norm": LayerNormParams(
+                    model.structure_module.layer_norm_z
+                ),
+                "fold_iteration": FoldIterationParams(model.structure_module),
+            },
+            "predicted_lddt_head": {
+                "input_layer_norm": LayerNormParams(
+                    model.aux_heads.plddt.layer_norm
+                ),
+                "act_0": LinearParams(model.aux_heads.plddt.linear_1),
+                "act_1": LinearParams(model.aux_heads.plddt.linear_2),
+                "logits": LinearParams(model.aux_heads.plddt.linear_3),
+            },
+            "distogram_head": {
+                "half_logits": LinearParams(model.aux_heads.distogram.linear),
+            },
+            "experimentally_resolved_head": {
+                "logits": LinearParams(
+                    model.aux_heads.experimentally_resolved.linear
+                ),
+            },
+            "masked_msa_head": {
+                "logits": LinearParams(model.aux_heads.masked_msa.linear),
+            },
+        }
     no_templ = [
         "model_3",
@@ -394,48 +536,98 @@ def generate_translation_dict(model, version):
     ]

     if version not in no_templ:
-        tps_blocks = model.template_pair_stack.blocks
+        tps_blocks = model.template_embedder.template_pair_stack.blocks
         tps_blocks_params = stacked(
             [TemplatePairBlockParams(b) for b in tps_blocks]
         )
-        template_param_dict = {
-            "template_embedding": {
-                "single_template_embedding": {
-                    "embedding2d": LinearParams(
-                        model.template_pair_embedder.linear
-                    ),
-                    "template_pair_stack": {
-                        "__layer_stack_no_state": tps_blocks_params,
-                    },
-                    "output_layer_norm": LayerNormParams(
-                        model.template_pair_stack.layer_norm
-                    ),
-                },
-                "attention": AttentionParams(model.template_pointwise_att.mha),
-            },
-            "template_single_embedding": LinearParams(
-                model.template_angle_embedder.linear_1
-            ),
-            "template_projection": LinearParams(
-                model.template_angle_embedder.linear_2
-            ),
-        }
+        if (not is_multimer):
+            template_param_dict = {
+                "template_embedding": {
+                    "single_template_embedding": {
+                        "embedding2d": LinearParams(
+                            model.template_embedder.template_pair_embedder.linear
+                        ),
+                        "template_pair_stack": {
+                            "__layer_stack_no_state": tps_blocks_params,
+                        },
+                        "output_layer_norm": LayerNormParams(
+                            model.template_embedder.template_pair_stack.layer_norm
+                        ),
+                    },
+                    "attention": AttentionParams(model.template_embedder.template_pointwise_att.mha),
+                },
+                "template_single_embedding": LinearParams(
+                    model.template_embedder.template_angle_embedder.linear_1
+                ),
+                "template_projection": LinearParams(
+                    model.template_embedder.template_angle_embedder.linear_2
+                ),
+            }
+        else:
+            temp_embedder = model.template_embedder
+
+            template_param_dict = {
+                "template_embedding": {
+                    "single_template_embedding": {
+                        "query_embedding_norm": LayerNormParams(
+                            temp_embedder.template_pair_embedder.query_embedding_layer_norm
+                        ),
+                        "template_pair_embedding_0": LinearParams(
+                            temp_embedder.template_pair_embedder.dgram_linear
+                        ),
+                        "template_pair_embedding_1": LinearParamsMultimer(
+                            temp_embedder.template_pair_embedder.pseudo_beta_mask_linear
+                        ),
+                        "template_pair_embedding_2": LinearParams(
+                            temp_embedder.template_pair_embedder.aatype_linear_1
+                        ),
+                        "template_pair_embedding_3": LinearParams(
+                            temp_embedder.template_pair_embedder.aatype_linear_2
+                        ),
+                        "template_pair_embedding_4": LinearParamsMultimer(
+                            temp_embedder.template_pair_embedder.x_linear
+                        ),
+                        "template_pair_embedding_5": LinearParamsMultimer(
+                            temp_embedder.template_pair_embedder.y_linear
+                        ),
+                        "template_pair_embedding_6": LinearParamsMultimer(
+                            temp_embedder.template_pair_embedder.z_linear
+                        ),
+                        "template_pair_embedding_7": LinearParamsMultimer(
+                            temp_embedder.template_pair_embedder.backbone_mask_linear
+                        ),
+                        "template_pair_embedding_8": LinearParams(
+                            temp_embedder.template_pair_embedder.query_embedding_linear
+                        ),
+                        "template_embedding_iteration": tps_blocks_params,
+                        "output_layer_norm": LayerNormParams(
+                            model.template_embedder.template_pair_stack.layer_norm
+                        ),
+                    },
+                    "output_linear": LinearParams(
+                        temp_embedder.linear_t
+                    ),
+                },
+                "template_projection": LinearParams(
+                    temp_embedder.template_single_embedder.template_projector,
+                ),
+                "template_single_embedding": LinearParams(
+                    temp_embedder.template_single_embedder.template_single_embedder,
+                ),
+            }

         translations["evoformer"].update(template_param_dict)

-    if "_ptm" in version:
+    if is_multimer or "_ptm" in version:
         translations["predicted_aligned_error_head"] = {
             "logits": LinearParams(model.aux_heads.tm.linear)
         }

     return translations
 def import_jax_weights_(model, npz_path, version="model_1"):
     data = np.load(npz_path)

-    translations = generate_translation_dict(model, version)
+    translations = generate_translation_dict(model, version, is_multimer=("multimer" in version))

     # Flatten keys and insert missing key prefixes
     flat = process_translation_dict(translations)
...
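A usage sketch for the updated entry point (the preset name and parameter path are illustrative; any version string containing "multimer" now takes the multimer translation dict):

from openfold.config import model_config
from openfold.model.model import AlphaFold
from openfold.utils.import_weights import import_jax_weights_

config = model_config("model_1_multimer_v3")    # illustrative preset name
model = AlphaFold(config)
import_jax_weights_(
    model,
    "params/params_model_1_multimer_v3.npz",    # illustrative path
    version="model_1_multimer_v3",
)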
@@ -25,6 +25,8 @@ from typing import Dict, Optional, Tuple
 from openfold.np import residue_constants
 from openfold.utils import feats
 from openfold.utils.rigid_utils import Rotation, Rigid
+from openfold.utils.geometry.vector import Vec3Array, euclidean_distance
+from openfold.utils.all_atom_multimer import get_rc_tensor
 from openfold.utils.tensor_utils import (
     tree_map,
     tensor_tree_map,
@@ -87,6 +89,7 @@ def compute_fape(
     target_positions: torch.Tensor,
     positions_mask: torch.Tensor,
     length_scale: float,
+    pair_mask: Optional[torch.Tensor] = None,
     l1_clamp_distance: Optional[float] = None,
     eps=1e-8,
 ) -> torch.Tensor:
@@ -108,6 +111,9 @@ def compute_fape(
             [*, N_pts] positions mask
         length_scale:
             Length scale by which the loss is divided
+        pair_mask:
+            [*, N_frames, N_pts] mask to use for
+            separating intra- from inter-chain losses.
         l1_clamp_distance:
             Cutoff above which distance errors are disregarded
         eps:
@@ -134,21 +140,30 @@ def compute_fape(
     normed_error = normed_error * frames_mask[..., None]
     normed_error = normed_error * positions_mask[..., None, :]

-    # FP16-friendly averaging. Roughly equivalent to:
-    #
-    # norm_factor = (
-    #     torch.sum(frames_mask, dim=-1) *
-    #     torch.sum(positions_mask, dim=-1)
-    # )
-    # normed_error = torch.sum(normed_error, dim=(-1, -2))
-    # normed_error = normed_error / (eps + norm_factor)
-    #
-    # ("roughly" because eps is necessarily duplicated in the latter)
-    normed_error = torch.sum(normed_error, dim=-1)
-    normed_error = (
-        normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
-    )
-    normed_error = torch.sum(normed_error, dim=-1)
-    normed_error = normed_error / (eps + torch.sum(positions_mask, dim=-1))
+    if pair_mask is not None:
+        normed_error = normed_error * pair_mask
+        normed_error = torch.sum(normed_error, dim=(-1, -2))
+
+        mask = frames_mask[..., None] * positions_mask[..., None, :] * pair_mask
+        norm_factor = torch.sum(mask, dim=(-2, -1))
+
+        normed_error = normed_error / (eps + norm_factor)
+    else:
+        # FP16-friendly averaging. Roughly equivalent to:
+        #
+        # norm_factor = (
+        #     torch.sum(frames_mask, dim=-1) *
+        #     torch.sum(positions_mask, dim=-1)
+        # )
+        # normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
+        #
+        # ("roughly" because eps is necessarily duplicated in the latter)
+        normed_error = torch.sum(normed_error, dim=-1)
+        normed_error = (
+            normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
+        )
+        normed_error = torch.sum(normed_error, dim=-1)
+        normed_error = normed_error / (eps + torch.sum(positions_mask, dim=-1))

     return normed_error
@@ -157,6 +172,7 @@ def backbone_loss(
     backbone_rigid_tensor: torch.Tensor,
     backbone_rigid_mask: torch.Tensor,
     traj: torch.Tensor,
+    pair_mask: Optional[torch.Tensor] = None,
     use_clamped_fape: Optional[torch.Tensor] = None,
     clamp_distance: float = 10.0,
     loss_unit_distance: float = 10.0,
@@ -184,6 +200,7 @@ def backbone_loss(
         pred_aff.get_trans(),
         gt_aff[None].get_trans(),
         backbone_rigid_mask[None],
+        pair_mask=pair_mask,
         l1_clamp_distance=clamp_distance,
         length_scale=loss_unit_distance,
         eps=eps,
@@ -196,6 +213,7 @@ def backbone_loss(
         pred_aff.get_trans(),
         gt_aff[None].get_trans(),
         backbone_rigid_mask[None],
+        pair_mask=pair_mask,
         l1_clamp_distance=None,
         length_scale=loss_unit_distance,
         eps=eps,
@@ -253,6 +271,7 @@ def sidechain_loss(
         sidechain_atom_pos,
         renamed_atom14_gt_positions,
         renamed_atom14_gt_exists,
+        pair_mask=None,
         l1_clamp_distance=clamp_distance,
         length_scale=length_scale,
         eps=eps,
@@ -266,10 +285,29 @@ def fape_loss(
     batch: Dict[str, torch.Tensor],
     config: ml_collections.ConfigDict,
 ) -> torch.Tensor:
-    bb_loss = backbone_loss(
-        traj=out["sm"]["frames"],
-        **{**batch, **config.backbone},
-    )
+    traj = out["sm"]["frames"]
+    asym_id = batch.get("asym_id")
+    if asym_id is not None:
+        intra_chain_mask = (asym_id[..., None] == asym_id[..., None, :]).to(dtype=traj.dtype)
+        intra_chain_bb_loss = backbone_loss(
+            traj=traj,
+            pair_mask=intra_chain_mask,
+            **{**batch, **config.intra_chain_backbone},
+        )
+
+        interface_bb_loss = backbone_loss(
+            traj=traj,
+            pair_mask=1. - intra_chain_mask,
+            **{**batch, **config.interface_backbone},
+        )
+
+        weighted_bb_loss = (intra_chain_bb_loss * config.intra_chain_backbone.weight
+                            + interface_bb_loss * config.interface_backbone.weight)
+    else:
+        bb_loss = backbone_loss(
+            traj=traj,
+            **{**batch, **config.backbone},
+        )
+        weighted_bb_loss = bb_loss * config.backbone.weight

     sc_loss = sidechain_loss(
         out["sm"]["sidechain_frames"],
@@ -277,7 +315,7 @@ def fape_loss(
         **{**batch, **config.sidechain},
     )

-    loss = config.backbone.weight * bb_loss + config.sidechain.weight * sc_loss
+    loss = weighted_bb_loss + config.sidechain.weight * sc_loss

     # Average over the batch dimension
     loss = torch.mean(loss)
@@ -627,6 +665,8 @@ def compute_predicted_aligned_error(
 def compute_tm(
     logits: torch.Tensor,
     residue_weights: Optional[torch.Tensor] = None,
+    asym_id: Optional[torch.Tensor] = None,
+    interface: bool = False,
     max_bin: int = 31,
     no_bins: int = 64,
     eps: float = 1e-8,
@@ -649,15 +689,25 @@ def compute_tm(
     tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2))
     predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)

-    normed_residue_mask = residue_weights / (eps + residue_weights.sum())
+    n = residue_weights.shape[-1]
+    pair_mask = residue_weights.new_ones((n, n), dtype=torch.int32)
+    if interface:
+        pair_mask *= (asym_id[..., None] != asym_id[..., None, :]).to(dtype=pair_mask.dtype)
+
+    predicted_tm_term *= pair_mask
+
+    pair_residue_weights = pair_mask * (
+        residue_weights[..., None, :] * residue_weights[..., :, None]
+    )
+    denom = eps + torch.sum(pair_residue_weights, dim=-1, keepdims=True)
+    normed_residue_mask = pair_residue_weights / denom
     per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)

     weighted = per_alignment * residue_weights

     argmax = (weighted == torch.max(weighted)).nonzero()[0]
     return per_alignment[tuple(argmax)]

 def tm_loss(
     logits,
     final_affine_tensor,
@@ -709,7 +759,7 @@ def tm_loss(
         (resolution >= min_resolution) & (resolution <= max_resolution)
     )

-    # Average over the loss dimension
+    # Average over the batch dimension
     loss = torch.mean(loss)

     return loss
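A sketch of the new interface mode (shapes illustrative): with interface=True and an asym_id tensor, the same logits that give pTM yield ipTM, since only inter-chain pairs survive the mask:

import torch

logits = torch.randn(100, 100, 64)                   # [N_res, N_res, no_bins]
asym_id = torch.cat([torch.zeros(60), torch.ones(40)]).long()
weights = torch.ones(100)
ptm = compute_tm(logits, residue_weights=weights)
iptm = compute_tm(logits, residue_weights=weights, asym_id=asym_id, interface=True)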
@@ -879,6 +929,7 @@ def between_residue_clash_loss(
     atom14_atom_exists: torch.Tensor,
     atom14_atom_radius: torch.Tensor,
     residue_index: torch.Tensor,
+    asym_id: Optional[torch.Tensor] = None,
     overlap_tolerance_soft=1.5,
     overlap_tolerance_hard=1.5,
     eps=1e-10,
@@ -954,9 +1005,13 @@ def between_residue_clash_loss(
     )
     n_one_hot = n_one_hot.type(fp_type)

-    neighbour_mask = (
-        residue_index[..., :, None, None, None] + 1
-    ) == residue_index[..., None, :, None, None]
+    neighbour_mask = (residue_index[..., :, None] + 1) == residue_index[..., None, :]
+
+    if asym_id is not None:
+        neighbour_mask = neighbour_mask & (asym_id[..., :, None] == asym_id[..., None, :])
+
+    neighbour_mask = neighbour_mask[..., None, None]

     c_n_bonds = (
         neighbour_mask
         * c_one_hot[..., None, None, :, None]
@@ -998,26 +1053,29 @@ def between_residue_clash_loss(
     # Compute the per atom loss sum.
     # shape (N, 14)
     per_atom_loss_sum = torch.sum(dists_to_low_error, dim=(-4, -2)) + torch.sum(
-        dists_to_low_error, axis=(-3, -1)
+        dists_to_low_error, dim=(-3, -1)
     )

     # Compute the hard clash mask.
     # shape (N, N, 14, 14)
     clash_mask = dists_mask * (
         dists < (dists_lower_bound - overlap_tolerance_hard)
     )

+    per_atom_num_clash = torch.sum(clash_mask, dim=(-4, -2)) + torch.sum(clash_mask, dim=(-3, -1))
+
     # Compute the per atom clash.
     # shape (N, 14)
     per_atom_clash_mask = torch.maximum(
-        torch.amax(clash_mask, axis=(-4, -2)),
-        torch.amax(clash_mask, axis=(-3, -1)),
+        torch.amax(clash_mask, dim=(-4, -2)),
+        torch.amax(clash_mask, dim=(-3, -1)),
     )

     return {
         "mean_loss": mean_loss,  # shape ()
         "per_atom_loss_sum": per_atom_loss_sum,  # shape (N, 14)
         "per_atom_clash_mask": per_atom_clash_mask,  # shape (N, 14)
+        "per_atom_num_clash": per_atom_num_clash  # shape (N, 14)
     }
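The asym_id guard above keeps the i/i+1 peptide-bond exemption from leaking across chain boundaries, where per-chain residue numbering can make unrelated residues look adjacent; a small check:

import torch

residue_index = torch.tensor([0, 1, 0, 1])   # numbering restarts on the second chain
asym_id = torch.tensor([0, 0, 1, 1])
nb = (residue_index[:, None] + 1) == residue_index[None, :]
nb = nb & (asym_id[:, None] == asym_id[None, :])
print(nb.nonzero().tolist())                 # [[0, 1], [2, 3]]: only true neighbours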
@@ -1097,6 +1155,8 @@ def within_residue_violations(
         (dists < atom14_dists_lower_bound) | (dists > atom14_dists_upper_bound)
     )

+    per_atom_num_clash = torch.sum(violations, dim=-2) + torch.sum(violations, dim=-1)
+
     # Compute the per atom violations.
     per_atom_violations = torch.maximum(
         torch.max(violations, dim=-2)[0], torch.max(violations, axis=-1)[0]
@@ -1105,6 +1165,7 @@ def within_residue_violations(
     return {
         "per_atom_loss_sum": per_atom_loss_sum,
         "per_atom_violations": per_atom_violations,
+        "per_atom_num_clash": per_atom_num_clash
     }
@@ -1134,11 +1195,24 @@ def find_structural_violations(
         residue_constants.van_der_waals_radius[name[0]]
         for name in residue_constants.atom_types
     ]
     atomtype_radius = atom14_pred_positions.new_tensor(atomtype_radius)
-    atom14_atom_radius = (
-        batch["atom14_atom_exists"]
-        * atomtype_radius[batch["residx_atom14_to_atom37"]]
-    )
+
+    #TODO: Consolidate monomer/multimer modes
+    asym_id = batch.get("asym_id")
+    if asym_id is not None:
+        residx_atom14_to_atom37 = get_rc_tensor(
+            residue_constants.RESTYPE_ATOM14_TO_ATOM37, batch["aatype"]
+        )
+        atom14_atom_radius = (
+            batch["atom14_atom_exists"]
+            * atomtype_radius[residx_atom14_to_atom37.long()]
+        )
+    else:
+        atom14_atom_radius = (
+            batch["atom14_atom_exists"]
+            * atomtype_radius[batch["residx_atom14_to_atom37"]]
+        )

     # Compute the between residue clash loss.
     between_residue_clashes = between_residue_clash_loss(
@@ -1146,6 +1220,7 @@ def find_structural_violations(
         atom14_atom_exists=batch["atom14_atom_exists"],
         atom14_atom_radius=atom14_atom_radius,
         residue_index=batch["residue_index"],
+        asym_id=asym_id,
         overlap_tolerance_soft=clash_overlap_tolerance,
         overlap_tolerance_hard=clash_overlap_tolerance,
     )
@@ -1208,6 +1283,9 @@ def find_structural_violations(
             "clashes_per_atom_clash_mask": between_residue_clashes[
                 "per_atom_clash_mask"
             ],  # (N, 14)
+            "clashes_per_atom_num_clash": between_residue_clashes[
+                "per_atom_num_clash"
+            ],  # (N, 14)
         },
         "within_residues": {
             "per_atom_loss_sum": residue_violations[
@@ -1216,6 +1294,9 @@ def find_structural_violations(
             "per_atom_violations": residue_violations[
                 "per_atom_violations"
             ],  # (N, 14),
+            "per_atom_num_clash": residue_violations[
+                "per_atom_num_clash"
+            ],  # (N, 14)
         },
         "total_per_residue_violations_mask": per_residue_violations_mask,  # (N)
     }
@@ -1337,15 +1418,21 @@ def compute_violation_metrics_np(
 def violation_loss(
     violations: Dict[str, torch.Tensor],
     atom14_atom_exists: torch.Tensor,
+    average_clashes: bool = False,
     eps=1e-6,
     **kwargs,
 ) -> torch.Tensor:
     num_atoms = torch.sum(atom14_atom_exists)
-    l_clash = torch.sum(
-        violations["between_residues"]["clashes_per_atom_loss_sum"]
-        + violations["within_residues"]["per_atom_loss_sum"]
-    )
-    l_clash = l_clash / (eps + num_atoms)
+
+    per_atom_clash = (violations["between_residues"]["clashes_per_atom_loss_sum"] +
+                      violations["within_residues"]["per_atom_loss_sum"])
+
+    if average_clashes:
+        num_clash = (violations["between_residues"]["clashes_per_atom_num_clash"] +
+                     violations["within_residues"]["per_atom_num_clash"])
+        per_atom_clash = per_atom_clash / (num_clash + eps)
+
+    l_clash = torch.sum(per_atom_clash) / (eps + num_atoms)
+
     loss = (
         violations["between_residues"]["bonds_c_n_loss_mean"]
         + violations["between_residues"]["angles_ca_c_n_loss_mean"]
@@ -1491,7 +1578,7 @@ def experimentally_resolved_loss(
     return loss

-def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
+def masked_msa_loss(logits, true_msa, bert_mask, num_classes, eps=1e-8, **kwargs):
     """
     Computes BERT-style masked MSA loss. Implements subsection 1.9.9.
@@ -1503,7 +1590,7 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
             Masked MSA loss
     """
     errors = softmax_cross_entropy(
-        logits, torch.nn.functional.one_hot(true_msa, num_classes=23)
+        logits, torch.nn.functional.one_hot(true_msa, num_classes=num_classes)
     )

     # FP16-friendly averaging. Equivalent to:
@@ -1524,6 +1611,64 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
     return loss
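The hard-coded 23 above (20 residue types + X + gap + mask, used by the monomer models) is now a num_classes argument, since the multimer models size this head differently via their config; a shape check with the monomer value:

import torch

true_msa = torch.randint(0, 23, (8, 64))     # [N_seq, N_res], illustrative
targets = torch.nn.functional.one_hot(true_msa, num_classes=23)
print(targets.shape)                         # torch.Size([8, 64, 23])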
+def chain_center_of_mass_loss(
+    all_atom_pred_pos: torch.Tensor,
+    all_atom_positions: torch.Tensor,
+    all_atom_mask: torch.Tensor,
+    asym_id: torch.Tensor,
+    clamp_distance: float = -4.0,
+    weight: float = 0.05,
+    eps: float = 1e-10
+) -> torch.Tensor:
+    """
+    Computes chain centre-of-mass loss. Implements section 2.5, eqn 1 in the Multimer paper.
+
+    Args:
+        all_atom_pred_pos:
+            [*, N_pts, 37, 3] All-atom predicted atom positions
+        all_atom_positions:
+            [*, N_pts, 37, 3] Ground truth all-atom positions
+        all_atom_mask:
+            [*, N_pts, 37] All-atom positions mask
+        asym_id:
+            [*, N_pts] Chain asym IDs
+        clamp_distance:
+            Cutoff above which distance errors are disregarded
+        weight:
+            Weight for loss
+        eps:
+            Small value used to regularize denominators
+    Returns:
+        [*] loss tensor
+    """
+    ca_pos = residue_constants.atom_order["CA"]
+    all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
+    all_atom_positions = all_atom_positions[..., ca_pos, :]
+    all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)]  # keep dim
+
+    chains, _ = asym_id.unique(return_counts=True)
+    one_hot = torch.nn.functional.one_hot(asym_id, num_classes=chains.shape[0]).to(dtype=all_atom_mask.dtype)
+    one_hot = one_hot * all_atom_mask
+    chain_pos_mask = one_hot.transpose(-2, -1)
+    chain_exists = torch.any(chain_pos_mask, dim=-1).float()
+
+    def get_chain_center_of_mass(pos):
+        center_sum = (chain_pos_mask[..., None] * pos[..., None, :, :]).sum(dim=-2)
+        centers = center_sum / (torch.sum(chain_pos_mask, dim=-1, keepdim=True) + eps)
+        return Vec3Array.from_array(centers)
+
+    pred_centers = get_chain_center_of_mass(all_atom_pred_pos)  # [B, NC, 3]
+    true_centers = get_chain_center_of_mass(all_atom_positions)  # [B, NC, 3]
+
+    pred_dists = euclidean_distance(pred_centers[..., None, :], pred_centers[..., :, None], epsilon=eps)
+    true_dists = euclidean_distance(true_centers[..., None, :], true_centers[..., :, None], epsilon=eps)
+    losses = torch.clamp((weight * (pred_dists - true_dists - clamp_distance)), max=0) ** 2
+    loss_mask = chain_exists[..., :, None] * chain_exists[..., None, :]
+
+    loss = masked_mean(loss_mask, losses, dim=(-1, -2))
+    return loss
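A shape-level usage sketch for the new loss (dummy inputs; one batch of two 4-residue chains):

import torch

pred = torch.randn(1, 8, 37, 3)
gt = torch.randn(1, 8, 37, 3)
mask = torch.ones(1, 8, 37)
asym_id = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1]])
loss = chain_center_of_mass_loss(pred, gt, mask, asym_id)
print(loss.shape)                            # torch.Size([1])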
 class AlphaFoldLoss(nn.Module):
     """Aggregation of the various losses described in the supplement"""
     def __init__(self, config):
@@ -1576,7 +1721,7 @@ class AlphaFoldLoss(nn.Module):
             ),
             "violation": lambda: violation_loss(
                 out["violation"],
-                **batch,
+                **{**batch, **self.config.violation},
             ),
         }
@@ -1586,6 +1731,12 @@ class AlphaFoldLoss(nn.Module):
                 **{**batch, **out, **self.config.tm},
             )

+        if (self.config.chain_center_of_mass.enabled):
+            loss_fns["chain_center_of_mass"] = lambda: chain_center_of_mass_loss(
+                all_atom_pred_pos=out["final_atom_positions"],
+                **{**batch, **self.config.chain_center_of_mass},
+            )
+
         cum_loss = 0.
         losses = {}
         for loss_name, loss_fn in loss_fns.items():
...
@@ -978,6 +978,16 @@ class Rigid:
         """
         return self._trans.device

+    @property
+    def dtype(self) -> torch.dtype:
+        """
+        Returns the dtype of the Rigid tensors.
+
+        Returns:
+            The dtype of the Rigid tensors
+        """
+        return self._rots.dtype
+
     def get_rots(self) -> Rotation:
         """
         Getter for the rotation.
...
@@ -219,7 +219,7 @@ def prep_output(out, batch, feature_dict, feature_processor, config_preset, mult
         features=batch,
         result=out,
         b_factors=plddt_b_factors,
-        chain_index=chain_index,
+        remove_leading_feature_dimension=not "multimer" in config_preset,
         remark=remark,
         parents=template_domain_names,
         parents_chain_index=template_chain_index,
...
@@ -44,6 +44,9 @@ if(
 torch.set_grad_enabled(False)

 from openfold.config import model_config
+from openfold.data.tools import hhsearch, hmmsearch
+from openfold.model.model import AlphaFold
+from openfold.model.torchscript import script_preset_
 from openfold.data import templates, feature_pipeline, data_pipeline
 from openfold.np import residue_constants, protein
 import openfold.np.relax.relax as relax
@@ -61,13 +64,19 @@ from scripts.utils import add_data_args
 TRACING_INTERVAL = 50

-def precompute_alignments(tags, seqs, alignment_dir, args):
+def precompute_alignments(tags, seqs, alignment_dir, args, is_multimer):
     for tag, seq in zip(tags, seqs):
         tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
         with open(tmp_fasta_path, "w") as fp:
             fp.write(f">{tag}\n{seq}")

-        local_alignment_dir = os.path.join(alignment_dir, tag)
+        if is_multimer:
+            local_alignment_dir = alignment_dir
+        else:
+            local_alignment_dir = os.path.join(
+                alignment_dir,
+                os.path.join(alignment_dir, tag),
+            )

         if(args.use_precomputed_alignments is None and not os.path.isdir(local_alignment_dir)):
             logger.info(f"Generating alignments for {tag}...")
@@ -76,12 +85,11 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
             alignment_runner = data_pipeline.AlignmentRunner(
                 jackhmmer_binary_path=args.jackhmmer_binary_path,
                 hhblits_binary_path=args.hhblits_binary_path,
-                hhsearch_binary_path=args.hhsearch_binary_path,
                 uniref90_database_path=args.uniref90_database_path,
                 mgnify_database_path=args.mgnify_database_path,
                 bfd_database_path=args.bfd_database_path,
+                uniref30_database_path=args.uniref30_database_path,
                 uniclust30_database_path=args.uniclust30_database_path,
-                pdb70_database_path=args.pdb70_database_path,
                 no_cpus=args.cpus,
             )
             alignment_runner.run(
@@ -118,6 +126,14 @@ def generate_feature_dict(
         feature_dict = data_processor.process_fasta(
             fasta_path=tmp_fasta_path, alignment_dir=local_alignment_dir
         )
+    elif "multimer" in args.config_preset:
+        with open(tmp_fasta_path, "w") as fp:
+            fp.write(
+                '\n'.join([f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)])
+            )
+        feature_dict = data_processor.process_fasta(
+            fasta_path=tmp_fasta_path, alignment_dir=alignment_dir,
+        )
     else:
         with open(tmp_fasta_path, "w") as fp:
             fp.write(
@@ -137,7 +153,7 @@ def list_files_with_extensions(dir, extensions):
 def main(args):
     # Create the output directory
     os.makedirs(args.output_dir, exist_ok=True)

     config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference)
@@ -148,19 +164,70 @@ def main(args):
             "Tracing requires that fixed_size mode be enabled in the config"
         )

-    template_featurizer = templates.TemplateHitFeaturizer(
-        mmcif_dir=args.template_mmcif_dir,
-        max_template_date=args.max_template_date,
-        max_hits=config.data.predict.max_templates,
-        kalign_binary_path=args.kalign_binary_path,
-        release_dates_path=args.release_dates_path,
-        obsolete_pdbs_path=args.obsolete_pdbs_path
-    )
+    is_multimer = "multimer" in args.config_preset
+
+    if(is_multimer):
+        if(not args.use_precomputed_alignments):
+            template_searcher = hmmsearch.Hmmsearch(
+                binary_path=args.hmmsearch_binary_path,
+                hmmbuild_binary_path=args.hmmbuild_binary_path,
+                database_path=args.pdb_seqres_database_path,
+            )
+        else:
+            template_searcher = None
+
+        template_featurizer = templates.HmmsearchHitFeaturizer(
+            mmcif_dir=args.template_mmcif_dir,
+            max_template_date=args.max_template_date,
+            max_hits=config.data.predict.max_templates,
+            kalign_binary_path=args.kalign_binary_path,
+            release_dates_path=args.release_dates_path,
+            obsolete_pdbs_path=args.obsolete_pdbs_path
+        )
+    else:
+        if(not args.use_precomputed_alignments):
+            template_searcher = hhsearch.HHSearch(
+                binary_path=args.hhsearch_binary_path,
+                databases=[args.pdb70_database_path],
+            )
+        else:
+            template_searcher = None
+
+        template_featurizer = templates.HhsearchHitFeaturizer(
+            mmcif_dir=args.template_mmcif_dir,
+            max_template_date=args.max_template_date,
+            max_hits=config.data.predict.max_templates,
+            kalign_binary_path=args.kalign_binary_path,
+            release_dates_path=args.release_dates_path,
+            obsolete_pdbs_path=args.obsolete_pdbs_path
+        )
+
+    if(not args.use_precomputed_alignments):
+        alignment_runner = data_pipeline.AlignmentRunner(
+            jackhmmer_binary_path=args.jackhmmer_binary_path,
+            hhblits_binary_path=args.hhblits_binary_path,
+            uniref90_database_path=args.uniref90_database_path,
+            mgnify_database_path=args.mgnify_database_path,
+            bfd_database_path=args.bfd_database_path,
+            uniref30_database_path=args.uniref30_database_path,
+            uniclust30_database_path=args.uniclust30_database_path,
+            uniprot_database_path=args.uniprot_database_path,
+            template_searcher=template_searcher,
+            use_small_bfd=(args.bfd_database_path is None),
+            no_cpus=args.cpus,
+        )
+    else:
+        alignment_runner = None

     data_processor = data_pipeline.DataPipeline(
         template_featurizer=template_featurizer,
     )

+    if(is_multimer):
+        data_processor = data_pipeline.DataPipelineMultimer(
+            monomer_data_pipeline=data_processor,
+        )
+
     output_dir_base = args.output_dir
     random_seed = args.data_random_seed
     if random_seed is None:
@@ -181,10 +248,19 @@ def main(args):
     seq_list = []
     for fasta_file in list_files_with_extensions(args.fasta_dir, (".fasta", ".fa")):
         # Gather input sequences
-        with open(os.path.join(args.fasta_dir, fasta_file), "r") as fp:
+        fasta_path = os.path.join(args.fasta_dir, fasta_file)
+        with open(fasta_path, "r") as fp:
             data = fp.read()

         tags, seqs = parse_fasta(data)

+        if ((not is_multimer) and len(tags) != 1):
+            print(
+                f"{fasta_path} contains more than one sequence but "
+                f"multimer mode is not enabled. Skipping..."
+            )
+            continue
+
         # assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
         tag = '-'.join(tags)
@@ -208,7 +284,7 @@ def main(args):
             output_name = f'{output_name}_{args.output_postfix}'

         # Does nothing if the alignments have already been computed
-        precompute_alignments(tags, seqs, alignment_dir, args)
+        precompute_alignments(tags, seqs, alignment_dir, args, is_multimer)

         feature_dict = feature_dicts.get(tag, None)
         if(feature_dict is None):
@@ -230,7 +306,7 @@ def main(args):
             feature_dicts[tag] = feature_dict

         processed_feature_dict = feature_processor.process_features(
-            feature_dict, mode='predict',
+            feature_dict, mode='predict', is_multimer=is_multimer
         )

         processed_feature_dict = {
@@ -238,8 +314,8 @@ def main(args):
             for k,v in processed_feature_dict.items()
         }

-        if(args.trace_model):
-            if(rounded_seqlen > cur_tracing_interval):
+        if (args.trace_model):
+            if (rounded_seqlen > cur_tracing_interval):
                 logger.info(
                     f"Tracing model at {rounded_seqlen} residues..."
                 )
...
import argparse
import logging
import os
import string
from collections import defaultdict
from openfold.data import mmcif_parsing
from openfold.np import protein, residue_constants
@@ -22,7 +23,7 @@ def main(args):
if(mmcif.mmcif_object is None):
logging.warning(f'Failed to parse {fname}...')
if(args.raise_errors):
raise Exception(list(mmcif.errors.values())[0])
else:
continue
@@ -31,6 +32,25 @@ def main(args):
chain_id = '_'.join([basename, chain])
fasta.append(f">{chain_id}")
fasta.append(seq)
elif(ext == ".pdb"):
with open(fpath, 'r') as fp:
pdb_str = fp.read()
protein_object = protein.from_pdb_string(pdb_str)
aatype = protein_object.aatype
chain_index = protein_object.chain_index
last_chain_index = chain_index[0]
chain_dict = defaultdict(list)
for i in range(aatype.shape[0]):
chain_dict[chain_index[i]].append(residue_constants.restypes_with_x[aatype[i]])
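# protein.chain_index stores integer chain ids; map them back to
# letters (A, B, ...) to build per-chain FASTA headers.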
chain_tags = string.ascii_uppercase
for chain, seq in chain_dict.items():
chain_id = '_'.join([basename, chain_tags[chain]])
fasta.append(f">{chain_id}")
fasta.append(''.join(seq))
elif(ext == ".core"):
with open(fpath, 'r') as fp:
core_str = fp.read()
...
import copy
import os
import torch
import deepspeed
local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))
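# LOCAL_RANK and WORLD_SIZE are set by the deepspeed (or torchrun)
# launcher; the defaults allow single-process runs.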
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.ml = torch.nn.ModuleList()
for _ in range(4000):
self.ml.append(torch.nn.Linear(500, 500))
def forward(self, batch):
for i, l in enumerate(self.ml):
# print(f"{i}: {l.weight.device}")
batch = l(batch)
return batch
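# 4000 Linear(500, 500) layers amount to roughly 4 GB of fp32 weights,
# enough to make sharding across GPUs worthwhile.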
class DummyDataset(torch.utils.data.Dataset):
def __init__(self):
self.batch = torch.rand(500, 500)
def __getitem__(self, idx):
return copy.deepcopy(self.batch)
def __len__(self):
return 1000
dd = DummyDataset()
dl = torch.utils.data.DataLoader(dd)
example = next(iter(dl)).to(f"cuda:{local_rank}")
model = Model()
model = model.to(f"cuda:{local_rank}")
model = deepspeed.init_inference(
model,
mp_size=world_size,
checkpoint=None,
replace_method=None,
#replace_method="auto"
)
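# init_inference partitions the module across world_size GPUs (mp_size is
# the tensor-parallel degree); replace_method="auto" would additionally
# swap supported modules for DeepSpeed's optimized inference kernels.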
out = model(example)
#if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
# print(out)
@@ -56,10 +56,16 @@ bash "${SCRIPT_DIR}/download_pdb70.sh" "${DOWNLOAD_DIR}"
echo "Downloading PDB mmCIF files..."
bash "${SCRIPT_DIR}/download_pdb_mmcif.sh" "${DOWNLOAD_DIR}"
echo "Downloading Uniref30..."
bash "${SCRIPT_DIR}/download_uniref30.sh" "${DOWNLOAD_DIR}"
echo "Downloading Uniref90..."
bash "${SCRIPT_DIR}/download_uniref90.sh" "${DOWNLOAD_DIR}"
echo "Downloading UniProt..."
bash "${SCRIPT_DIR}/download_uniprot.sh" "${DOWNLOAD_DIR}"
echo "Downloading PDB SeqRes..."
bash "${SCRIPT_DIR}/download_pdb_seqres.sh" "${DOWNLOAD_DIR}"
echo "All data downloaded." echo "All data downloaded."
@@ -31,7 +31,7 @@ fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/params"
SOURCE_URL="https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar"
BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}"
...
@@ -32,8 +32,8 @@ fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/mgnify"
# Mirror of:
# ftp://ftp.ebi.ac.uk/pub/databases/metagenomics/peptide_database/2022_05/mgy_clusters.fa.gz
SOURCE_URL="https://storage.googleapis.com/alphafold-databases/v2.3/mgy_clusters_2022_05.fa.gz"
BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}"
...
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads and unzips the PDB SeqRes database for AlphaFold.
#
# Usage: bash download_pdb_seqres.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument."
exit 1
fi
if ! command -v aria2c &> /dev/null ; then
echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
exit 1
fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/pdb_seqres"
SOURCE_URL="ftp://ftp.wwpdb.org/pub/pdb/derived_data/pdb_seqres.txt"
BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}"
aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
# Keep only protein sequences.
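# Each record in pdb_seqres.txt is a header line followed by a single
# sequence line, so --after-context=1 captures complete entries.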
grep --after-context=1 --no-group-separator '>.* mol:protein' "${ROOT_DIR}/pdb_seqres.txt" > "${ROOT_DIR}/pdb_seqres_filtered.txt"
mv "${ROOT_DIR}/pdb_seqres_filtered.txt" "${ROOT_DIR}/pdb_seqres.txt"
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads, unzips and merges the SwissProt and TrEMBL databases for
# AlphaFold-Multimer.
#
# Usage: bash download_uniprot.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument."
exit 1
fi
if ! command -v aria2c &> /dev/null ; then
echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
exit 1
fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/uniprot"
TREMBL_SOURCE_URL="ftp://ftp.ebi.ac.uk/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_trembl.fasta.gz"
TREMBL_BASENAME=$(basename "${TREMBL_SOURCE_URL}")
TREMBL_UNZIPPED_BASENAME="${TREMBL_BASENAME%.gz}"
SPROT_SOURCE_URL="ftp://ftp.ebi.ac.uk/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_sprot.fasta.gz"
SPROT_BASENAME=$(basename "${SPROT_SOURCE_URL}")
SPROT_UNZIPPED_BASENAME="${SPROT_BASENAME%.gz}"
mkdir --parents "${ROOT_DIR}"
aria2c "${TREMBL_SOURCE_URL}" --dir="${ROOT_DIR}"
aria2c "${SPROT_SOURCE_URL}" --dir="${ROOT_DIR}"
pushd "${ROOT_DIR}"
gunzip "${ROOT_DIR}/${TREMBL_BASENAME}"
gunzip "${ROOT_DIR}/${SPROT_BASENAME}"
# Concatenate TrEMBL and SwissProt, rename to uniprot and clean up.
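# Appending the (much smaller) SwissProt file to TrEMBL avoids rewriting
# the far larger TrEMBL FASTA.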
cat "${ROOT_DIR}/${SPROT_UNZIPPED_BASENAME}" >> "${ROOT_DIR}/${TREMBL_UNZIPPED_BASENAME}"
mv "${ROOT_DIR}/${TREMBL_UNZIPPED_BASENAME}" "${ROOT_DIR}/uniprot.fasta"
rm "${ROOT_DIR}/${SPROT_UNZIPPED_BASENAME}"
popd
@@ -30,10 +30,15 @@ if ! command -v aria2c &> /dev/null ; then
fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/uniref30"
# Mirror of:
# https://wwwuser.gwdg.de/~compbiol/uniclust/2021_03/UniRef30_2021_03.tar.gz
SOURCE_URL="https://storage.googleapis.com/alphafold-databases/v2.3/UniRef30_2021_03.tar.gz"
BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}"
aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}" -x 4 --check-certificate=false
gunzip "${ROOT_DIR}/${BASENAME}" tar --extract --verbose --file="${ROOT_DIR}/${BASENAME}" \
--directory="${ROOT_DIR}"
rm "${ROOT_DIR}/${BASENAME}"
@@ -26,7 +26,12 @@ mkdir -p "${ALIGNMENT_DIR}"
for chain_dir in $(ls "${RODA_DIR}"); do
CHAIN_DIR_PATH="${RODA_DIR}/${chain_dir}"
for subdir in $(ls "${CHAIN_DIR_PATH}"); do
if [[ ! -d "${CHAIN_DIR_PATH}/${subdir}" ]]; then
echo "${subdir} is not a directory"
continue
elif [[ -z $(ls "${CHAIN_DIR_PATH}/${subdir}") ]]; then
continue
elif [[ $subdir = "pdb" ]] || [[ $subdir = "cif" ]]; then
mv "${CHAIN_DIR_PATH}/${subdir}"/* "${DATA_DIR}" mv "${CHAIN_DIR_PATH}/${subdir}"/* "${DATA_DIR}"
else else
CHAIN_ALIGNMENT_DIR="${ALIGNMENT_DIR}/${chain_dir}" CHAIN_ALIGNMENT_DIR="${ALIGNMENT_DIR}/${chain_dir}"
......
@@ -2,35 +2,62 @@ import argparse
import os
import pickle
from alphafold.data import pipeline, pipeline_multimer, templates
from alphafold.data.tools import hmmsearch, hhsearch
from scripts.utils import add_data_args
def main(args):
if (args.multimer):
template_searcher = hmmsearch.Hmmsearch(
binary_path=args.hmmsearch_binary_path,
hmmbuild_binary_path=args.hmmbuild_binary_path,
database_path=args.pdb_seqres_database_path,
)
template_featurizer = templates.HmmsearchHitFeaturizer(
mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date,
max_hits=20,
kalign_binary_path=args.kalign_binary_path,
release_dates_path=args.release_dates_path,
obsolete_pdbs_path=args.obsolete_pdbs_path
)
else:
template_searcher = hhsearch.HHSearch(
binary_path=args.hhsearch_binary_path,
databases=[args.pdb70_database_path],
)
template_featurizer = templates.HhsearchHitFeaturizer(
mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date,
max_hits=20,
kalign_binary_path=args.kalign_binary_path,
release_dates_path=None,
obsolete_pdbs_path=args.obsolete_pdbs_path
)
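# The searcher/featurizer pair must match: hmmsearch hits are parsed by
# HmmsearchHitFeaturizer and hhsearch hits by HhsearchHitFeaturizer.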
data_pipeline = pipeline.DataPipeline(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
small_bfd_database_path=None,
template_featurizer=template_featurizer,
template_searcher=template_searcher,
use_small_bfd=False,
)
if (args.multimer):
data_pipeline = pipeline_multimer.DataPipeline(
monomer_data_pipeline=data_pipeline,
jackhmmer_binary_path=args.jackhmmer_binary_path,
uniprot_database_path=args.uniprot_database_path)
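# The multimer pipeline additionally runs jackhmmer against UniProt,
# whose species annotations are used to pair MSA rows across chains.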
feature_dict = data_pipeline.process(
input_fasta_path=args.fasta_path,
msa_output_dir=args.output_dir,
@@ -44,6 +71,7 @@ if __name__ == "__main__":
parser.add_argument("fasta_path", type=str)
parser.add_argument("mmcif_dir", type=str)
parser.add_argument("output_dir", type=str)
parser.add_argument("--multimer", action='store_true')
add_data_args(parser)
args = parser.parse_args()
...
@@ -4,10 +4,11 @@ import json
import logging
from multiprocessing import Pool
import os
import string
import sys
sys.path.append(".") # an innocent hack to get this to run from the top level
from collections import defaultdict
from tqdm import tqdm
from openfold.data.mmcif_parsing import parse
@@ -49,20 +50,27 @@ def parse_file(
pdb_string = fp.read()
protein_object = protein.from_pdb_string(pdb_string, None)
aatype = protein_object.aatype
chain_index = protein_object.chain_index
chain_dict = defaultdict(list)
for i in range(aatype.shape[0]):
chain_dict[chain_index[i]].append(residue_constants.restypes_with_x[aatype[i]])
out = {}
chain_tags = string.ascii_uppercase
for chain, seq in chain_dict.items():
full_name = "_".join([file_id, chain_tags[chain]])
out[full_name] = {}
local_data = out[full_name]
local_data["resolution"] = 0.
local_data["seq"] = ''.join(seq)
if(chain_cluster_size_dict is not None):
cluster_size = chain_cluster_size_dict.get(
full_name.upper(), -1
)
local_data["cluster_size"] = cluster_size
return out
...
@@ -11,6 +11,7 @@ import tempfile
import openfold.data.mmcif_parsing as mmcif_parsing
from openfold.data.data_pipeline import AlignmentRunner
from openfold.data.parsers import parse_fasta
from openfold.data.tools import hhsearch, hmmsearch
from openfold.np import protein, residue_constants
from utils import add_data_args
@@ -39,7 +40,8 @@ def run_seq_group_alignments(seq_groups, alignment_runner, args):
alignment_runner.run(
fasta_path, alignment_dir
)
except Exception as e:
logging.warning(e)
logging.warning(f"Failed to run alignments for {first_name}. Skipping...")
os.remove(fasta_path)
os.rmdir(alignment_dir)
@@ -114,15 +116,30 @@ def parse_and_align(files, alignment_runner, args):
def main(args):
# Build the alignment tool runner
if (args.hmmsearch_binary_path is not None):
template_searcher = hmmsearch.Hmmsearch(
binary_path=args.hmmsearch_binary_path,
hmmbuild_binary_path=args.hmmbuild_binary_path,
database_path=args.pdb_seqres_database_path,
)
elif (args.hhsearch_binary_path is not None):
template_searcher = hhsearch.HHSearch(
binary_path=args.hhsearch_binary_path,
databases=[args.pdb70_database_path],
)
else:
template_searcher = None
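# Prefer multimer-style hmmsearch when available, fall back to hhsearch,
# and skip template search entirely if neither binary is given.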
alignment_runner = AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniref30_database_path=args.uniref30_database_path,
uniclust30_database_path=args.uniclust30_database_path,
uniprot_database_path=args.uniprot_database_path,
template_searcher=template_searcher,
use_small_bfd=args.bfd_database_path is None,
no_cpus=args.cpus_per_task,
)
...
@@ -14,9 +14,18 @@ def add_data_args(parser: argparse.ArgumentParser):
parser.add_argument(
'--pdb70_database_path', type=str, default=None,
)
parser.add_argument(
'--pdb_seqres_database_path', type=str, default=None,
)
parser.add_argument(
'--uniref30_database_path', type=str, default=None,
)
parser.add_argument(
'--uniclust30_database_path', type=str, default=None,
)
parser.add_argument(
'--uniprot_database_path', type=str, default=None,
)
parser.add_argument(
'--bfd_database_path', type=str, default=None,
)
@@ -29,6 +38,12 @@ def add_data_args(parser: argparse.ArgumentParser):
parser.add_argument(
'--hhsearch_binary_path', type=str, default='/usr/bin/hhsearch'
)
parser.add_argument(
'--hmmsearch_binary_path', type=str, default='/usr/bin/hmmsearch'
)
parser.add_argument(
'--hmmbuild_binary_path', type=str, default='/usr/bin/hmmbuild'
)
parser.add_argument(
'--kalign_binary_path', type=str, default='/usr/bin/kalign'
)
...