Commit a1c29028 by zhangqha: update uni-fold
import torch
from unicore.utils import one_hot
from unifold.data import residue_constants as rc
from .utils import (
sigmoid_cross_entropy,
softmax_cross_entropy,
masked_mean,
)
from .geometry import (
compute_aligned_error,
compute_distogram,
compute_lddt,
)
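# Auxiliary "experimentally resolved" head loss: per-atom sigmoid cross-entropy
# between the head logits and the ground-truth atom37 mask, normalized over all
# existing atoms and applied only to targets whose resolution falls within
# [min_resolution, max_resolution].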
def experimentally_resolved_loss(
logits: torch.Tensor,
atom37_atom_exists: torch.Tensor,
all_atom_mask: torch.Tensor,
resolution: torch.Tensor,
min_resolution: float,
max_resolution: float,
eps: float = 1e-8,
loss_dict: dict = None,
**kwargs,
) -> torch.Tensor:
atom37_atom_exists = atom37_atom_exists.float()
all_atom_mask = all_atom_mask.float()
errors = sigmoid_cross_entropy(logits.float(), all_atom_mask)
loss = torch.sum(errors * atom37_atom_exists, dim=-1)
dnorm = torch.sum(atom37_atom_exists, dim=(-1, -2)).unsqueeze(-1)
loss = loss / (eps + dnorm)
loss = torch.sum(loss, dim=-1)
loss = loss * ((resolution >= min_resolution) & (resolution <= max_resolution))
loss_dict["experimentally_resolved"] = loss.data
return loss
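# Predicted-lDDT (pLDDT) head loss: the per-residue lDDT-Ca score of the
# prediction is computed against the ground truth, discretized into num_bins
# bins, and the head is trained with softmax cross-entropy. The mean Ca lDDT
# is also logged for monitoring.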
def plddt_loss(
logits: torch.Tensor,
all_atom_pred_pos: torch.Tensor,
all_atom_positions: torch.Tensor,
all_atom_mask: torch.Tensor,
resolution: torch.Tensor,
cutoff: float = 15.0,
num_bins: int = 50,
min_resolution: float = 0.1,
max_resolution: float = 3.0,
eps: float = 1e-10,
loss_dict: dict = None,
**kwargs,
) -> torch.Tensor:
# TODO: bin utils
ca_pos = rc.atom_order["CA"]
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :].float()
all_atom_positions = all_atom_positions[..., ca_pos, :].float()
all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)].float() # keep dim
lddt = compute_lddt(
all_atom_pred_pos, all_atom_positions, all_atom_mask, cutoff=cutoff, eps=eps
).detach()
bin_index = torch.floor(lddt * num_bins).long()
bin_index = torch.clamp(bin_index, max=(num_bins - 1))
lddt_ca_one_hot = one_hot(bin_index, num_classes=num_bins)
errors = softmax_cross_entropy(logits, lddt_ca_one_hot)
all_atom_mask = all_atom_mask.squeeze(-1)
loss = masked_mean(all_atom_mask, errors, dim=-1, eps=eps)
loss = loss * ((resolution >= min_resolution) & (resolution <= max_resolution))
ca_lddt = masked_mean(all_atom_mask, lddt, dim=-1, eps=eps)
loss_dict["ca_lddt_score"] = ca_lddt.data
loss_dict["plddt_loss"] = loss.data
return loss
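# Supervised torsion-angle loss: squared error between predicted and true
# sin/cos chi angles, taking the minimum over the pi-shifted alternative for
# pi-periodic side chains, plus a penalty (weighted by angle_norm_weight)
# keeping the unnormalized sin/cos predictions close to unit norm.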
def supervised_chi_loss(
pred_angles_sin_cos: torch.Tensor,
pred_unnormed_angles_sin_cos: torch.Tensor,
true_angles_sin_cos: torch.Tensor,
aatype: torch.Tensor,
seq_mask: torch.Tensor,
chi_mask: torch.Tensor,
chi_weight: float,
angle_norm_weight: float,
eps=1e-6,
loss_dict=None,
**kwargs,
) -> torch.Tensor:
# TODO: refactor this.
pred_angles_sin_cos = pred_angles_sin_cos.float()
pred_unnormed_angles_sin_cos = pred_unnormed_angles_sin_cos.float()
true_angles_sin_cos = true_angles_sin_cos.unsqueeze(0).float()
seq_mask = seq_mask.float()
chi_mask = chi_mask.float()
pred_angles = pred_angles_sin_cos[..., 3:, :]
residue_type_one_hot = one_hot(
aatype,
rc.restype_num + 1,
)
chi_pi_periodic = torch.einsum(
"ijk, kl->ijl",
residue_type_one_hot.type(pred_angles_sin_cos.dtype),
pred_angles_sin_cos.new_tensor(rc.chi_pi_periodic),
)
true_chi = true_angles_sin_cos
shifted_mask = (1.0 - 2.0 * chi_pi_periodic)[None, ..., None]
true_chi_shifted = shifted_mask * true_chi
sq_chi_error = torch.sum((true_chi - pred_angles) ** 2, dim=-1)
sq_chi_error_shifted = torch.sum((true_chi_shifted - pred_angles) ** 2, dim=-1)
sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted)
# permute nblock and batch dim
sq_chi_error = sq_chi_error.transpose(0, 1)
mask = chi_mask.unsqueeze(1)
sq_chi_loss = masked_mean(mask, sq_chi_error, dim=(-1, -2, -3))
loss_dict["chi_loss"] = sq_chi_loss.data
loss = chi_weight * sq_chi_loss
angle_norm = torch.sqrt(torch.sum(pred_unnormed_angles_sin_cos**2, dim=-1) + eps)
norm_error = torch.abs(angle_norm - 1.0)
norm_error = norm_error.transpose(0, 1)
mask = seq_mask[..., None, :, None]
angle_norm_loss = masked_mean(mask, norm_error, dim=(-1, -2, -3))
loss_dict["angle_norm_loss"] = angle_norm_loss.data
loss = loss + angle_norm_weight * angle_norm_loss
return loss
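# Regularizer on the MSA and pair representations: penalizes the amount by
# which each vector's L2 norm deviates from sqrt(d) beyond the given tolerance,
# averaged over the masked positions.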
def repr_norm_loss(
msa_norm: torch.Tensor,
pair_norm: torch.Tensor,
msa_mask: torch.Tensor,
pseudo_beta_mask: torch.Tensor,
loss_dict=None,
eps=1e-5,
tolerance=0.0,
**kwargs,
) -> torch.Tensor:
def norm_loss(x):
max_norm = x.shape[-1] ** 0.5
norm = torch.sqrt(torch.sum(x**2, dim=-1) + eps)
error = torch.nn.functional.relu((norm - max_norm).abs() - tolerance)
return error
pair_norm_error = norm_loss(pair_norm.float())
msa_norm_error = norm_loss(msa_norm.float())
pair_mask = pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]
pair_norm_loss = masked_mean(pair_mask.float(), pair_norm_error, dim=(-1, -2))
msa_norm_loss = masked_mean(msa_mask.float(), msa_norm_error, dim=(-1, -2))
loss = pair_norm_loss + msa_norm_loss
loss_dict["pair_norm"] = pair_norm_loss.data
loss_dict["msa_norm"] = msa_norm_loss.data
return loss
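# Distogram head loss: softmax cross-entropy between the predicted distance-bin
# logits and the binned pairwise pseudo-beta distances.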
def distogram_loss(
logits,
pseudo_beta,
pseudo_beta_mask,
min_bin=2.3125,
max_bin=21.6875,
num_bins=64,
eps=1e-6,
loss_dict=None,
**kwargs,
):
distogram, mask = compute_distogram(
pseudo_beta, pseudo_beta_mask, min_bin, max_bin, num_bins)
errors = softmax_cross_entropy(logits, one_hot(distogram, num_bins))
loss = masked_mean(mask, errors, dim=(-1, -2), eps=eps)
loss_dict["distogram"] = loss.data
return loss
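# Predicted aligned error (PAE) head loss: the pairwise aligned error between
# predicted and true frames is discretized into num_bins bins and the head is
# trained with softmax cross-entropy, gated by the target resolution window.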
def pae_loss(
logits,
pred_frame_tensor,
true_frame_tensor,
frame_mask,
resolution,
max_bin=31,
num_bins=64,
min_resolution: float = 0.1,
max_resolution: float = 3.0,
eps=1e-8,
loss_dict=None,
**kwargs,
):
aligned_error_val, aligned_error_bin, mask = compute_aligned_error(
pred_frame_tensor,
true_frame_tensor,
frame_mask,
max_bin,
num_bins,
)
errors = softmax_cross_entropy(logits.float(), one_hot(aligned_error_bin, num_bins))
loss = masked_mean(mask, errors, dim=(-1, -2), eps=eps)
loss = loss * ((resolution >= min_resolution) & (resolution <= max_resolution))
loss_dict["pae_loss"] = loss.data
loss_dict["aligned_error"] = aligned_error_val.data
return loss
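# Masked-MSA (BERT-style) reconstruction loss: softmax cross-entropy between
# the MSA head logits and the true MSA tokens, evaluated only at positions that
# were masked out (bert_mask).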
def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, loss_dict=None, **kwargs):
bert_mask = bert_mask.float()
errors = softmax_cross_entropy(
logits.float(), one_hot(true_msa.long(), num_classes=logits.shape[-1])
)
loss = masked_mean(bert_mask, errors, dim=(-1, -2), eps=eps)
loss_dict["masked_msa"] = loss.data
return loss
def get_asym_mask(asym_id):
"""get the mask for each asym_id. [*, NR] -> [*, NC, NR]"""
# this func presumes that valid asym_id ranges [1, NC] and is dense.
asym_type = torch.arange(1, torch.amax(asym_id) + 1, device=asym_id.device) # [NC]
return (asym_id[..., None, :] == asym_type[:, None]).float()
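# Chain centre-of-mass loss for multimers: compares pairwise distances between
# per-chain Ca centres in the prediction and the ground truth with a
# flat-bottomed squared penalty that activates when two centres are predicted
# more than 4 A closer together than they are in the ground truth.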
def chain_centre_mass_loss(
pred_atom_positions: torch.Tensor,
true_atom_positions: torch.Tensor,
atom_mask: torch.Tensor,
asym_id: torch.Tensor,
eps: float = 1e-10,
loss_dict=None,
**kwargs,
) -> torch.Tensor:
ca_pos = rc.atom_order["CA"]
pred_atom_positions = pred_atom_positions[..., ca_pos, :].float() # [B, NR, 3]
true_atom_positions = true_atom_positions[..., ca_pos, :].float() # [B, NR, 3]
atom_mask = atom_mask[..., ca_pos].bool() # [B, NR]
assert len(pred_atom_positions.shape) == 3
asym_mask = get_asym_mask(asym_id) * atom_mask[..., None, :] # [B, NC, NR]
asym_exists = torch.any(asym_mask, dim=-1).float() # [B, NC]
def get_asym_centres(pos):
pos = pos[..., None, :, :] * asym_mask[..., :, :, None] # [B, NC, NR, 3]
return torch.sum(pos, dim=-2) / (torch.sum(asym_mask, dim=-1)[..., None] + eps)
pred_centres = get_asym_centres(pred_atom_positions) # [B, NC, 3]
true_centres = get_asym_centres(true_atom_positions) # [B, NC, 3]
def get_dist(p1: torch.Tensor, p2: torch.Tensor):
return torch.sqrt(
(p1[..., :, None, :] - p2[..., None, :, :]).square().sum(-1) + eps
)
pred_centres2 = pred_centres
true_centres2 = true_centres
pred_dists = get_dist(pred_centres, pred_centres2) # [B, NC, NC]
true_dists = get_dist(true_centres, true_centres2) # [B, NC, NC]
losses = (pred_dists - true_dists + 4).clamp(max=0).square() * 0.0025
loss_mask = asym_exists[..., :, None] * asym_exists[..., None, :] # [B, NC, NC]
loss = masked_mean(loss_mask, losses, dim=(-1, -2))
loss_dict["chain_centre_loss"] = loss.data
return loss
import torch
from unifold.data import residue_constants as rc
from .geometry import kabsch_rmsd, get_optimal_transform, compute_rmsd
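# Multi-chain permutation alignment: resolves the chain-assignment ambiguity of
# symmetric multimer targets. An anchor chain is superimposed onto the
# prediction, the remaining chains of each entity are greedily matched by Ca
# RMSD (with several random shuffles), and the permutation with the lowest
# global Kabsch RMSD is used to merge the ground-truth labels.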
def multi_chain_perm_align(out, batch, labels, shuffle_times=2):
assert isinstance(labels, list)
ca_idx = rc.atom_order["CA"]
pred_ca_pos = out["final_atom_positions"][..., ca_idx, :].float() # [bsz, nres, 3]
pred_ca_mask = out["final_atom_mask"][..., ca_idx].float() # [bsz, nres]
true_ca_poses = [
l["all_atom_positions"][..., ca_idx, :].float() for l in labels
] # list([nres, 3])
true_ca_masks = [
l["all_atom_mask"][..., ca_idx].float() for l in labels
] # list([nres,])
unique_asym_ids = torch.unique(batch["asym_id"])
per_asym_residue_index = {}
for cur_asym_id in unique_asym_ids:
asym_mask = (batch["asym_id"] == cur_asym_id).bool()
per_asym_residue_index[int(cur_asym_id)] = batch["residue_index"][asym_mask]
anchor_gt_asym, anchor_pred_asym = get_anchor_candidates(
batch, per_asym_residue_index, true_ca_masks
)
anchor_gt_idx = int(anchor_gt_asym) - 1
best_rmsd = 1e9
best_labels = None
unique_entity_ids = torch.unique(batch["entity_id"])
entity_2_asym_list = {}
for cur_ent_id in unique_entity_ids:
ent_mask = batch["entity_id"] == cur_ent_id
cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
for cur_asym_id in anchor_pred_asym:
asym_mask = (batch["asym_id"] == cur_asym_id).bool()
anchor_residue_idx = per_asym_residue_index[int(cur_asym_id)]
anchor_true_pos = true_ca_poses[anchor_gt_idx][anchor_residue_idx]
anchor_pred_pos = pred_ca_pos[asym_mask]
anchor_true_mask = true_ca_masks[anchor_gt_idx][anchor_residue_idx]
anchor_pred_mask = pred_ca_mask[asym_mask]
r, x = get_optimal_transform(
anchor_true_pos,
anchor_pred_pos,
(anchor_true_mask * anchor_pred_mask).bool(),
)
aligned_true_ca_poses = [ca @ r + x for ca in true_ca_poses] # apply transforms
for _ in range(shuffle_times):
shuffle_idx = torch.randperm(
unique_asym_ids.shape[0], device=unique_asym_ids.device
)
shuffled_asym_ids = unique_asym_ids[shuffle_idx]
align = greedy_align(
batch,
per_asym_residue_index,
shuffled_asym_ids,
entity_2_asym_list,
pred_ca_pos,
pred_ca_mask,
aligned_true_ca_poses,
true_ca_masks,
)
merged_labels = merge_labels(
batch,
per_asym_residue_index,
labels,
align,
)
rmsd = kabsch_rmsd(
merged_labels["all_atom_positions"][..., ca_idx, :] @ r + x,
pred_ca_pos,
(pred_ca_mask * merged_labels["all_atom_mask"][..., ca_idx]).bool(),
)
if rmsd < best_rmsd:
best_rmsd = rmsd
best_labels = merged_labels
return best_labels
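# Anchor selection: among ground-truth chains with the fewest symmetric copies
# (smallest positive num_sym), pick the one with the most resolved Ca atoms,
# and return all predicted chains sharing its entity as alignment candidates.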
def get_anchor_candidates(batch, per_asym_residue_index, true_masks):
def find_by_num_sym(min_num_sym):
best_len = -1
best_gt_asym = None
asym_ids = torch.unique(batch["asym_id"][batch["num_sym"] == min_num_sym])
for cur_asym_id in asym_ids:
assert cur_asym_id > 0
cur_residue_index = per_asym_residue_index[int(cur_asym_id)]
j = int(cur_asym_id - 1)
cur_true_mask = true_masks[j][cur_residue_index]
cur_len = cur_true_mask.sum()
if cur_len > best_len:
best_len = cur_len
best_gt_asym = cur_asym_id
return best_gt_asym, best_len
sorted_num_sym = batch["num_sym"][batch["num_sym"] > 0].sort()[0]
best_gt_asym = None
best_len = -1
for cur_num_sym in sorted_num_sym:
if cur_num_sym <= 0:
continue
cur_gt_sym, cur_len = find_by_num_sym(cur_num_sym)
if cur_len > best_len:
best_len = cur_len
best_gt_asym = cur_gt_sym
if best_len >= 3:
break
best_entity = batch["entity_id"][batch["asym_id"] == best_gt_asym][0]
best_pred_asym = torch.unique(batch["asym_id"][batch["entity_id"] == best_entity])
return best_gt_asym, best_pred_asym
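# Greedy chain matching: for each predicted chain (in the given, possibly
# shuffled, order), assign the not-yet-used ground-truth chain of the same
# entity with the lowest Ca RMSD.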
def greedy_align(
batch,
per_asym_residue_index,
unique_asym_ids,
entity_2_asym_list,
pred_ca_pos,
pred_ca_mask,
true_ca_poses,
true_ca_masks,
):
used = [False for _ in range(len(true_ca_poses))]
align = []
for cur_asym_id in unique_asym_ids:
# skip padding
if cur_asym_id == 0:
continue
i = int(cur_asym_id - 1)
asym_mask = batch["asym_id"] == cur_asym_id
num_sym = batch["num_sym"][asym_mask][0]
        # chains with a single copy don't need to be aligned
        if num_sym == 1:
            align.append((i, i))
            assert not used[i]
            used[i] = True
continue
cur_entity_ids = batch["entity_id"][asym_mask][0]
best_rmsd = 1e10
best_idx = None
cur_asym_list = entity_2_asym_list[int(cur_entity_ids)]
cur_residue_index = per_asym_residue_index[int(cur_asym_id)]
cur_pred_pos = pred_ca_pos[asym_mask]
cur_pred_mask = pred_ca_mask[asym_mask]
for next_asym_id in cur_asym_list:
if next_asym_id == 0:
continue
j = int(next_asym_id - 1)
            if not used[j]:  # possible candidate
cropped_pos = true_ca_poses[j][cur_residue_index]
mask = true_ca_masks[j][cur_residue_index]
rmsd = compute_rmsd(
cropped_pos, cur_pred_pos, (cur_pred_mask * mask).bool()
)
if rmsd < best_rmsd:
best_rmsd = rmsd
best_idx = j
assert best_idx is not None
used[best_idx] = True
align.append((i, best_idx))
return align
def merge_labels(batch, per_asym_residue_index, labels, align):
"""
batch:
labels: list of label dicts, each with shape [nk, *]
align: list of int, such as [2, None, 0, 1], each entry specify the corresponding label of the asym.
"""
num_res = batch["msa_mask"].shape[-1]
outs = {}
for k, v in labels[0].items():
if k in [
"resolution",
]:
continue
cur_out = {}
for i, j in align:
label = labels[j][k]
# to 1-based
cur_residue_index = per_asym_residue_index[i + 1]
cur_out[i] = label[cur_residue_index]
cur_out = [x[1] for x in sorted(cur_out.items())]
new_v = torch.concat(cur_out, dim=0)
merged_nres = new_v.shape[0]
assert (
merged_nres <= num_res
), f"bad merged num res: {merged_nres} > {num_res}. something is wrong."
if merged_nres < num_res: # must pad
pad_dim = new_v.shape[1:]
pad_v = new_v.new_zeros((num_res - merged_nres, *pad_dim))
new_v = torch.concat((new_v, pad_v), dim=0)
outs[k] = new_v
return outs
import ml_collections
import torch
from typing import Dict
from .geometry import compute_fape
from unifold.modules.frame import Frame
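# Backbone FAPE over the structure-module frame trajectory. When an intra-chain
# mask is provided (multimer case), intra-chain and inter-chain (interface)
# FAPE are computed separately with their own clamping distances and length
# scales, and both are returned.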
def backbone_loss(
true_frame_tensor: torch.Tensor,
frame_mask: torch.Tensor,
traj: torch.Tensor,
use_clamped_fape: torch.Tensor,
clamp_distance: float = 10.0,
loss_unit_distance: float = 10.0,
clamp_distance_between_chains: float = 30.0,
loss_unit_distance_between_chains: float = 20.0,
intra_chain_mask: torch.Tensor = None,
eps: float = 1e-4,
**kwargs,
) -> torch.Tensor:
pred_aff = Frame.from_tensor_4x4(traj)
gt_aff = Frame.from_tensor_4x4(true_frame_tensor)
use_clamped_fape = int(use_clamped_fape) == 1
if intra_chain_mask is None:
return compute_fape(
pred_aff,
gt_aff[None],
frame_mask[None],
pred_aff.get_trans(),
gt_aff[None].get_trans(),
frame_mask[None],
pair_mask=None,
l1_clamp_distance=clamp_distance if use_clamped_fape else None,
length_scale=loss_unit_distance,
eps=eps,
)
else:
intra_chain_mask = intra_chain_mask.float().unsqueeze(0)
intra_chain_bb_loss = compute_fape(
pred_aff,
gt_aff[None],
frame_mask[None],
pred_aff.get_trans(),
gt_aff[None].get_trans(),
frame_mask[None],
pair_mask=intra_chain_mask,
l1_clamp_distance=clamp_distance if use_clamped_fape else None,
length_scale=loss_unit_distance,
eps=eps,
)
interface_fape = compute_fape(
pred_aff,
gt_aff[None],
frame_mask[None],
pred_aff.get_trans(),
gt_aff[None].get_trans(),
frame_mask[None],
pair_mask=1.0 - intra_chain_mask,
l1_clamp_distance=clamp_distance_between_chains
if use_clamped_fape
else None,
length_scale=loss_unit_distance_between_chains,
eps=eps,
)
return intra_chain_bb_loss, interface_fape
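# Side-chain FAPE: all rigid-group frames and atom14 positions are compared
# against the ambiguity-renamed ground truth (see compute_renamed_ground_truth),
# with a fixed clamping distance.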
def sidechain_loss(
sidechain_frames: torch.Tensor,
sidechain_atom_pos: torch.Tensor,
rigidgroups_gt_frames: torch.Tensor,
rigidgroups_alt_gt_frames: torch.Tensor,
rigidgroups_gt_exists: torch.Tensor,
renamed_atom14_gt_positions: torch.Tensor,
renamed_atom14_gt_exists: torch.Tensor,
alt_naming_is_better: torch.Tensor,
clamp_distance: float = 10.0,
length_scale: float = 10.0,
eps: float = 1e-4,
**kwargs,
) -> torch.Tensor:
renamed_gt_frames = (
1.0 - alt_naming_is_better[..., None, None, None]
) * rigidgroups_gt_frames + alt_naming_is_better[
..., None, None, None
] * rigidgroups_alt_gt_frames
batch_dims = sidechain_frames.shape[:-4]
sidechain_frames = sidechain_frames.view(*batch_dims, -1, 4, 4)
sidechain_frames = Frame.from_tensor_4x4(sidechain_frames)
renamed_gt_frames = renamed_gt_frames.view(*batch_dims, -1, 4, 4)
renamed_gt_frames = Frame.from_tensor_4x4(renamed_gt_frames)
rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(*batch_dims, -1)
sidechain_atom_pos = sidechain_atom_pos.view(*batch_dims, -1, 3)
renamed_atom14_gt_positions = renamed_atom14_gt_positions.view(*batch_dims, -1, 3)
renamed_atom14_gt_exists = renamed_atom14_gt_exists.view(*batch_dims, -1)
fape = compute_fape(
sidechain_frames,
renamed_gt_frames,
rigidgroups_gt_exists,
sidechain_atom_pos,
renamed_atom14_gt_positions,
renamed_atom14_gt_exists,
pair_mask=None,
l1_clamp_distance=clamp_distance,
length_scale=length_scale,
eps=eps,
)
return fape
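# Combines the backbone (and, for multimers, interface) FAPE averaged over
# structure-module iterations with the side-chain FAPE, weighted according to
# the loss config; only the last-iteration values are logged.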
def fape_loss(
out: Dict[str, torch.Tensor],
batch: Dict[str, torch.Tensor],
config: ml_collections.ConfigDict,
loss_dict: dict,
) -> torch.Tensor:
for key in out["sm"]:
out["sm"][key] = out["sm"][key].float()
if "asym_id" in batch:
intra_chain_mask = (
batch["asym_id"][..., :, None] == batch["asym_id"][..., None, :]
)
bb_loss, interface_loss = backbone_loss(
traj=out["sm"]["frames"],
**{**batch, **config.backbone},
intra_chain_mask=intra_chain_mask,
)
# only show the loss on last layer
loss_dict["fape"] = bb_loss[-1].data
loss_dict["interface_fape"] = interface_loss[-1].data
bb_loss = torch.mean(bb_loss, dim=0) + torch.mean(interface_loss, dim=0)
else:
bb_loss = backbone_loss(
traj=out["sm"]["frames"],
**{**batch, **config.backbone},
intra_chain_mask=None,
)
# only show the loss on last layer
loss_dict["fape"] = bb_loss[-1].data
bb_loss = torch.mean(bb_loss, dim=0)
sc_loss = sidechain_loss(
out["sm"]["sidechain_frames"],
out["sm"]["positions"],
**{**batch, **config.sidechain},
)
loss_dict["sc_fape"] = sc_loss.data
loss = config.backbone.weight * bb_loss + config.sidechain.weight * sc_loss
return loss
import torch
from unifold.data import residue_constants as rc
from unifold.modules.frame import Frame
from typing import Dict, Tuple
from unicore.utils import (
permute_final_dims,
set_jit_fusion_options,
)
set_jit_fusion_options()
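# lDDT computed on the given atoms: for atom pairs within a 15 A inclusion
# radius in the ground truth, scores the fraction of distances preserved within
# 0.5/1/2/4 A, averaged per atom.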
def compute_lddt(
all_atom_pred_pos: torch.Tensor,
all_atom_positions: torch.Tensor,
all_atom_mask: torch.Tensor,
cutoff: float = 15.0,
eps: float = 1e-10,
) -> torch.Tensor:
n = all_atom_mask.shape[-2]
dmat_true = torch.sqrt(
eps
+ torch.sum(
(all_atom_positions[..., None, :] - all_atom_positions[..., None, :, :])
** 2,
dim=-1,
)
)
dmat_pred = torch.sqrt(
eps
+ torch.sum(
(all_atom_pred_pos[..., None, :] - all_atom_pred_pos[..., None, :, :]) ** 2,
dim=-1,
)
)
dists_to_score = (
(dmat_true < cutoff)
* all_atom_mask
* permute_final_dims(all_atom_mask, (1, 0))
* (1.0 - torch.eye(n, device=all_atom_mask.device))
)
dist_l1 = torch.abs(dmat_true - dmat_pred)
score = (
(dist_l1 < 0.5).type(dist_l1.dtype)
+ (dist_l1 < 1.0).type(dist_l1.dtype)
+ (dist_l1 < 2.0).type(dist_l1.dtype)
+ (dist_l1 < 4.0).type(dist_l1.dtype)
)
score = score * 0.25
norm = 1.0 / (eps + torch.sum(dists_to_score, dim=-1))
score = norm * (eps + torch.sum(dists_to_score * score, dim=-1))
return score
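# Frame Aligned Point Error (FAPE): positions are expressed in the local
# coordinates of every predicted/target frame, the pointwise L2 error is
# optionally clamped at l1_clamp_distance, scaled by length_scale, and averaged
# over valid frame/point pairs (optionally restricted by pair_mask).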
def compute_fape(
pred_frames: Frame,
target_frames: Frame,
frames_mask: torch.Tensor,
pred_positions: torch.Tensor,
target_positions: torch.Tensor,
positions_mask: torch.Tensor,
pair_mask: torch.Tensor,
length_scale: float,
l1_clamp_distance: float,
eps: float = 1e-4,
) -> torch.Tensor:
local_pred_pos = pred_frames.invert()[..., None].apply(
pred_positions[..., None, :, :].float(),
)
local_target_pos = target_frames.invert()[..., None].apply(
target_positions[..., None, :, :].float(),
)
frames_mask = frames_mask.float()
positions_mask = positions_mask.float()
error_dist = torch.sqrt(
torch.sum((local_pred_pos - local_target_pos) ** 2, dim=-1) + eps
)
if l1_clamp_distance is not None:
error_dist = torch.clamp(error_dist, min=0, max=l1_clamp_distance)
normed_error = error_dist / length_scale
normed_error *= frames_mask[..., None]
normed_error *= positions_mask[..., None, :]
if pair_mask is not None:
normed_error *= pair_mask
if pair_mask is not None:
mask = frames_mask.unsqueeze(-1) * positions_mask.unsqueeze(-2)
mask *= pair_mask
norm_factor = mask.sum(dim=(-1, -2))
else:
norm_factor = torch.sum(frames_mask, dim=-1) * torch.sum(positions_mask, dim=-1)
normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
return normed_error
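# Bins the squared pairwise pseudo-beta distances into num_bins bins whose
# boundaries are spaced between min_bin and max_bin (in Angstrom); returns the
# per-pair bin index and the pair mask.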
def compute_distogram(
positions,
mask,
min_bin=2.3125,
max_bin=21.6875,
num_bins=64,
):
boundaries = torch.linspace(
min_bin,
max_bin,
num_bins - 1,
device=positions.device,
)
boundaries = boundaries**2
positions = positions.float()
dists = torch.sum(
(positions[..., None, :] - positions[..., None, :, :]) ** 2,
dim=-1,
        keepdim=True,
).detach()
mask = mask.float()
pair_mask = mask[..., None] * mask[..., None, :]
return torch.sum(dists > boundaries, dim=-1), pair_mask
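# Pairwise aligned error: each frame's translation is expressed in the local
# coordinates of every other frame for both prediction and ground truth;
# returns the per-pair error, its bin index (boundaries from 0 to max_bin), and
# the pair mask.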
def compute_aligned_error(
pred_affine_tensor: torch.Tensor,
true_affine_tensor: torch.Tensor,
affine_mask: torch.Tensor,
max_bin: int = 31,
num_bins: int = 64,
eps: float = 1e-10,
):
pred_affine = Frame.from_tensor_4x4(pred_affine_tensor.float())
true_affine = Frame.from_tensor_4x4(true_affine_tensor.float())
def _points(affine):
pts = affine.get_trans()[..., None, :, :]
return affine.invert()[..., None].apply(pts)
sq_diff = torch.sum(
(_points(pred_affine) - _points(true_affine)) ** 2, dim=-1
).detach()
boundaries = torch.linspace(
0, max_bin, steps=(num_bins - 1), device=pred_affine_tensor.device
)
boundaries = boundaries**2
affine_mask = affine_mask.float()
pair_mask = affine_mask[..., None] * affine_mask[..., None, :]
return (
torch.sqrt(sq_diff + eps),
torch.sum(sq_diff[..., None] > boundaries, dim=-1),
pair_mask,
)
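# Resolves 180-degree-symmetric side-chain naming ambiguities: for each residue,
# the ground-truth naming (original or alternative) that better matches the
# predicted inter-atom distances is selected, and the renamed positions and
# masks are returned together with the per-residue choice.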
def compute_renamed_ground_truth(
batch: Dict[str, torch.Tensor],
atom14_pred_positions: torch.Tensor,
eps=1e-10,
) -> Dict[str, torch.Tensor]:
atom14_pred_positions = atom14_pred_positions.float()
pred_dists = torch.sqrt(
eps
+ torch.sum(
(
atom14_pred_positions[..., None, :, None, :]
- atom14_pred_positions[..., None, :, None, :, :]
)
** 2,
dim=-1,
)
)
atom14_gt_positions = batch["atom14_gt_positions"].float()
gt_dists = torch.sqrt(
eps
+ torch.sum(
(
atom14_gt_positions[..., None, :, None, :]
- atom14_gt_positions[..., None, :, None, :, :]
)
** 2,
dim=-1,
)
)
atom14_alt_gt_positions = batch["atom14_alt_gt_positions"].float()
alt_gt_dists = torch.sqrt(
eps
+ torch.sum(
(
atom14_alt_gt_positions[..., None, :, None, :]
- atom14_alt_gt_positions[..., None, :, None, :, :]
)
** 2,
dim=-1,
)
)
lddt = torch.sqrt(eps + (pred_dists - gt_dists) ** 2)
alt_lddt = torch.sqrt(eps + (pred_dists - alt_gt_dists) ** 2)
atom14_gt_exists = batch["atom14_gt_exists"].float()
atom14_atom_is_ambiguous = batch["atom14_atom_is_ambiguous"].float()
mask = (
atom14_gt_exists[..., None, :, None]
* atom14_atom_is_ambiguous[..., None, :, None]
* atom14_gt_exists[..., None, :, None, :]
* (1.0 - atom14_atom_is_ambiguous[..., None, :, None, :])
)
per_res_lddt = torch.sum(mask * lddt, dim=(-1, -2, -3))
alt_per_res_lddt = torch.sum(mask * alt_lddt, dim=(-1, -2, -3))
fp_type = atom14_pred_positions.dtype
alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).type(fp_type)
renamed_atom14_gt_positions = (
1.0 - alt_naming_is_better[..., None, None]
) * atom14_gt_positions + alt_naming_is_better[
..., None, None
] * atom14_alt_gt_positions
renamed_atom14_gt_mask = (
1.0 - alt_naming_is_better[..., None]
) * atom14_gt_exists + alt_naming_is_better[..., None] * batch[
"atom14_alt_gt_exists"
].float()
return {
"alt_naming_is_better": alt_naming_is_better,
"renamed_atom14_gt_positions": renamed_atom14_gt_positions,
"renamed_atom14_gt_exists": renamed_atom14_gt_mask,
}
@torch.jit.script
def compute_rmsd(
true_atom_pos: torch.Tensor,
pred_atom_pos: torch.Tensor,
# atom_mask: torch.Tensor = None,
atom_mask: torch.BoolTensor,
eps: float = 1e-6,
) -> torch.Tensor:
# shape check
sq_diff = torch.square(true_atom_pos - pred_atom_pos).sum(dim=-1, keepdim=False)
if atom_mask is not None:
sq_diff = sq_diff[atom_mask]
msd = torch.mean(sq_diff)
msd = torch.nan_to_num(msd, nan=1e8)
return torch.sqrt(msd + eps)
@torch.jit.script
def kabsch_rotation(P, Q):
"""
Using the Kabsch algorithm with two sets of paired point P and Q, centered
around the centroid. Each vector set is represented as an NxD
matrix, where D is the the dimension of the space.
The algorithm works in three steps:
- a centroid translation of P and Q (assumed done before this function
call)
- the computation of a covariance matrix C
- computation of the optimal rotation matrix U
For more info see http://en.wikipedia.org/wiki/Kabsch_algorithm
Parameters
----------
P : array
(N,D) matrix, where N is points and D is dimension.
Q : array
(N,D) matrix, where N is points and D is dimension.
Returns
-------
U : matrix
Rotation matrix (D,D)
"""
# Computation of the covariance matrix
C = P.transpose(-1, -2) @ Q
# Computation of the optimal rotation matrix
# This can be done using singular value decomposition (SVD)
# Getting the sign of the det(V)*(W) to decide
# whether we need to correct our rotation matrix to ensure a
# right-handed coordinate system.
# And finally calculating the optimal rotation matrix U
# see http://en.wikipedia.org/wiki/Kabsch_algorithm
V, _, W = torch.linalg.svd(C)
d = (torch.linalg.det(V) * torch.linalg.det(W)) < 0.0
if d:
V[:, -1] = -V[:, -1]
# Create Rotation matrix U
U = V @ W
return U
@torch.jit.script
def get_optimal_transform(
src_atoms: torch.Tensor,
tgt_atoms: torch.Tensor,
# mask: torch.Tensor = None,
mask: torch.BoolTensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
assert src_atoms.shape == tgt_atoms.shape, (src_atoms.shape, tgt_atoms.shape)
assert src_atoms.shape[-1] == 3
if mask is not None:
assert mask.dtype == torch.bool
assert mask.shape[-1] == src_atoms.shape[-2]
if mask.sum() == 0:
src_atoms = torch.zeros((1, 3), device=src_atoms.device).float()
tgt_atoms = src_atoms
else:
src_atoms = src_atoms[mask, :]
tgt_atoms = tgt_atoms[mask, :]
src_center = src_atoms.mean(-2, keepdim=True)
tgt_center = tgt_atoms.mean(-2, keepdim=True)
r = kabsch_rotation(src_atoms - src_center, tgt_atoms - tgt_center)
x = tgt_center - src_center @ r
return r, x
@torch.jit.script
def kabsch_rmsd(
true_atom_pos: torch.Tensor,
pred_atom_pos: torch.Tensor,
atom_mask: torch.Tensor,
):
r, x = get_optimal_transform(
true_atom_pos,
pred_atom_pos,
atom_mask,
)
aligned_true_atom_pos = true_atom_pos @ r + x
return compute_rmsd(aligned_true_atom_pos, pred_atom_pos, atom_mask)
def get_optimal_transform_v2(
p: torch.Tensor,
q: torch.Tensor,
m: torch.Tensor,
num_dim: int = 1,
eps: float = 1e-6,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    calculate u, x such that p @ u + x ~ q.
    p, q have shape [*, *dim, 3]
    m (mask) has shape [*, *dim]
    returns u of shape [*, *ones, 3, 3] and x of shape [*, *ones, 3],
    where *ones are the *dim axes reduced to size 1.
    """
rd = p.shape[-1]
batch_shape = p.shape[: -(num_dim + 1)]
m = m.reshape(*batch_shape, -1, 1)
def process_input(p):
p = p.reshape(*batch_shape, -1, rd)
p = p * m
cp = p.sum(dim=-2, keepdim=True) / (
eps + m.sum(dim=-2, keepdim=True)
) # [*, 1, 3]
p = p - cp
p = p * m
return p, cp
p_rc, cp = process_input(p) # rc for remove center
q_rc, cq = process_input(q)
c = p_rc.transpose(-1, -2) @ q_rc # [*, 3, 3]
v, _, w = torch.linalg.svd(c) # [*, 3, 3]
d = (torch.linalg.det(v) * torch.linalg.det(w) >= 0.0).type(
v.dtype
) * 2.0 - 1.0 # [*]
v[..., -1] = v[..., -1] * d[..., None] # [*, 3]
u = v @ w # [*, 3, 3]
u = u.reshape(*batch_shape, *((1,) * num_dim), rd, rd)
cp = cp.reshape(*batch_shape, *((1,) * num_dim), rd)
cq = cq.reshape(*batch_shape, *((1,) * num_dim), rd)
x = cq[..., None, :] - cp[..., None, :] @ u
return u, x.squeeze(-2)
def apply_optimal_transform_v2(x, r, t):
return (x.unsqueeze(-2) @ r + t.unsqueeze(-2)).squeeze(-2)
def compute_rmsd_v2(p1, p2, mask, dim=-1, eps=1e-8):
sd = torch.square(p1 - p2).sum(dim=-1, keepdim=False)
msd = torch.sum(sd * mask, dim=dim) / (eps + torch.sum(mask, dim=dim))
return torch.sqrt(msd + eps)
def kabsch_rmsd_v2(
true_atom_pos: torch.Tensor,
pred_atom_pos: torch.Tensor,
true_atom_mask: torch.Tensor,
pred_atom_mask: torch.Tensor,
num_dim: int = 1,
):
r, t = get_optimal_transform_v2(
true_atom_pos, pred_atom_pos, true_atom_mask * pred_atom_mask, num_dim
)
aligned_true_atom_pos = apply_optimal_transform_v2(true_atom_pos, r, t)
reduce_dim = tuple(-k - 1 for k in range(num_dim))
return compute_rmsd_v2(
aligned_true_atom_pos,
pred_atom_pos,
true_atom_mask * pred_atom_mask,
dim=reduce_dim,
)
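# Validation metrics: Ca RMSD and a TM-score-like measure computed after an
# optimal (Kabsch) superposition of the predicted Ca trace onto the ground
# truth, using a TM-score-style d0(L) normalization.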
def compute_metric(features, out, eps=1e-6):
ca_idx = rc.atom_order["CA"]
true_ca: torch.Tensor = features["all_atom_positions"][..., ca_idx, :]
pred_ca = out["final_atom_positions"][..., ca_idx, :]
mask: torch.Tensor = features["all_atom_mask"] * out["final_atom_mask"]
mask = mask[..., ca_idx]
r, t = get_optimal_transform_v2(pred_ca, true_ca, mask, num_dim=1)
aln_pred_ca: torch.Tensor = apply_optimal_transform_v2(pred_ca, r, t)
sd = (aln_pred_ca - true_ca).square().sum(dim=-1) # [*, n]
nres = mask.sum(dim=-1, keepdim=True) # [*, 1]
d0 = 1.24 * torch.clamp(nres, min=15) ** (1.0 / 3.0) - 1.8
    tm_term = 1.0 / (1.0 + sd / (d0**2))  # TM-score term: 1 / (1 + (d_i / d0)^2)
msd = torch.sum(sd * mask, dim=-1) / (eps + torch.sum(mask, dim=-1))
rmsd = torch.sqrt(msd + eps)
tm = torch.sum(tm_term * mask, dim=-1) / (eps + torch.sum(mask, dim=-1))
return {
"rmsd": rmsd.data,
"tm_score": tm.data,
}
import torch
from unifold.data import residue_constants as rc
def softmax_cross_entropy(logits, labels):
loss = -1 * torch.sum(
labels * torch.nn.functional.log_softmax(logits.float(), dim=-1),
dim=-1,
)
return loss
def sigmoid_cross_entropy(logits, labels):
logits = logits.float()
log_p = torch.nn.functional.logsigmoid(logits)
log_not_p = torch.nn.functional.logsigmoid(-logits)
loss = -labels * log_p - (1 - labels) * log_not_p
return loss
def masked_mean(mask, value, dim, eps=1e-10, keepdim=False):
mask = mask.expand(*value.shape)
return torch.sum(mask * value, dim=dim, keepdim=keepdim) / (
eps + torch.sum(mask, dim=dim, keepdim=keepdim)
)
import torch
from typing import Dict
from unicore.utils import one_hot
from .utils import masked_mean
from unifold.data import residue_constants as rc
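# Flat-bottomed penalties on backbone geometry between consecutive residues:
# C-N bond length and CA-C-N / C-N-CA bond angles, with proline-specific ideal
# values for the C-N bond. Residue-index gaps and chain changes (asym_id) are
# excluded via the gap mask.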
def between_residue_bond_loss(
pred_atom_positions: torch.Tensor,
pred_atom_mask: torch.Tensor,
residue_index: torch.Tensor,
aatype: torch.Tensor,
asym_id: torch.Tensor,
tolerance_factor_soft=12.0,
tolerance_factor_hard=12.0,
eps=1e-6,
) -> Dict[str, torch.Tensor]:
pred_atom_positions = pred_atom_positions.float()
pred_atom_mask = pred_atom_mask.float()
this_ca_pos = pred_atom_positions[..., :-1, 1, :]
this_ca_mask = pred_atom_mask[..., :-1, 1]
this_c_pos = pred_atom_positions[..., :-1, 2, :]
this_c_mask = pred_atom_mask[..., :-1, 2]
next_n_pos = pred_atom_positions[..., 1:, 0, :]
next_n_mask = pred_atom_mask[..., 1:, 0]
next_ca_pos = pred_atom_positions[..., 1:, 1, :]
next_ca_mask = pred_atom_mask[..., 1:, 1]
has_no_gap_mask = (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0
# mask gap between different chains
if asym_id is not None:
has_no_gap_mask &= asym_id[..., :-1] == asym_id[..., 1:]
has_no_gap_mask = has_no_gap_mask.float()
c_n_bond_length = torch.sqrt(
eps + torch.sum((this_c_pos - next_n_pos) ** 2, dim=-1)
)
next_is_proline = (aatype[..., 1:] == rc.resname_to_idx["PRO"]).float()
gt_length = (1.0 - next_is_proline) * rc.between_res_bond_length_c_n[
0
] + next_is_proline * rc.between_res_bond_length_c_n[1]
gt_stddev = (1.0 - next_is_proline) * rc.between_res_bond_length_stddev_c_n[
0
] + next_is_proline * rc.between_res_bond_length_stddev_c_n[1]
c_n_bond_length_error = torch.sqrt(eps + (c_n_bond_length - gt_length) ** 2)
c_n_loss_per_residue = torch.nn.functional.relu(
c_n_bond_length_error - tolerance_factor_soft * gt_stddev
)
mask = this_c_mask * next_n_mask * has_no_gap_mask
c_n_loss = torch.sum(mask * c_n_loss_per_residue, dim=-1) / (
torch.sum(mask, dim=-1) + eps
)
c_n_violation_mask = (
mask * (c_n_bond_length_error > (tolerance_factor_hard * gt_stddev)).float()
)
ca_c_bond_length = torch.sqrt(
eps + torch.sum((this_ca_pos - this_c_pos) ** 2, dim=-1)
)
n_ca_bond_length = torch.sqrt(
eps + torch.sum((next_n_pos - next_ca_pos) ** 2, dim=-1)
)
c_ca_unit_vec = (this_ca_pos - this_c_pos) / ca_c_bond_length[..., None]
c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length[..., None]
n_ca_unit_vec = (next_ca_pos - next_n_pos) / n_ca_bond_length[..., None]
ca_c_n_cos_angle = torch.sum(c_ca_unit_vec * c_n_unit_vec, dim=-1)
gt_angle = rc.between_res_cos_angles_ca_c_n[0]
gt_stddev = rc.between_res_bond_length_stddev_c_n[0]
ca_c_n_cos_angle_error = torch.sqrt(eps + (ca_c_n_cos_angle - gt_angle) ** 2)
ca_c_n_loss_per_residue = torch.nn.functional.relu(
ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev
)
mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask
ca_c_n_loss = torch.sum(mask * ca_c_n_loss_per_residue, dim=-1) / (
torch.sum(mask, dim=-1) + eps
)
ca_c_n_violation_mask = mask * (
ca_c_n_cos_angle_error > (tolerance_factor_hard * gt_stddev)
)
c_n_ca_cos_angle = torch.sum((-c_n_unit_vec) * n_ca_unit_vec, dim=-1)
gt_angle = rc.between_res_cos_angles_c_n_ca[0]
gt_stddev = rc.between_res_cos_angles_c_n_ca[1]
c_n_ca_cos_angle_error = torch.sqrt(eps + torch.square(c_n_ca_cos_angle - gt_angle))
c_n_ca_loss_per_residue = torch.nn.functional.relu(
c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev
)
mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask
c_n_ca_loss = torch.sum(mask * c_n_ca_loss_per_residue, dim=-1) / (
torch.sum(mask, dim=-1) + eps
)
c_n_ca_violation_mask = mask * (
c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev)
)
per_residue_loss_sum = (
c_n_loss_per_residue + ca_c_n_loss_per_residue + c_n_ca_loss_per_residue
)
per_residue_loss_sum = 0.5 * (
torch.nn.functional.pad(per_residue_loss_sum, (0, 1))
+ torch.nn.functional.pad(per_residue_loss_sum, (1, 0))
)
violation_mask = torch.max(
torch.stack(
[c_n_violation_mask, ca_c_n_violation_mask, c_n_ca_violation_mask],
dim=-2,
),
dim=-2,
)[0]
violation_mask = torch.maximum(
torch.nn.functional.pad(violation_mask, (0, 1)),
torch.nn.functional.pad(violation_mask, (1, 0)),
)
return {
"c_n_loss_mean": c_n_loss,
"ca_c_n_loss_mean": ca_c_n_loss,
"c_n_ca_loss_mean": c_n_ca_loss,
"per_residue_loss_sum": per_residue_loss_sum,
"per_residue_violation_mask": violation_mask,
}
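# Steric clash penalty between atoms of different residues: penalizes pairwise
# distances that fall below the sum of Van der Waals radii (minus the overlap
# tolerance), excluding the bonded C-N atoms of neighbouring residues and
# CYS SG-SG pairs (potential disulfide bonds).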
def between_residue_clash_loss(
atom14_pred_positions: torch.Tensor,
atom14_atom_exists: torch.Tensor,
atom14_atom_radius: torch.Tensor,
residue_index: torch.Tensor,
asym_id: torch.Tensor,
overlap_tolerance_soft=1.5,
overlap_tolerance_hard=1.5,
) -> Dict[str, torch.Tensor]:
atom14_pred_positions = atom14_pred_positions.float()
fp_type = atom14_pred_positions.dtype
dists = torch.sqrt(
1e-10
+ torch.sum(
(
atom14_pred_positions[..., :, None, :, None, :]
- atom14_pred_positions[..., None, :, None, :, :]
)
** 2,
dim=-1,
)
)
dists_mask = (
atom14_atom_exists[..., :, None, :, None]
* atom14_atom_exists[..., None, :, None, :]
).type(fp_type)
dists_mask = (
dists_mask
* (
residue_index[..., :, None, None, None]
<= residue_index[..., None, :, None, None]
).float()
)
diagonal = (
residue_index[..., :, None, None, None]
== residue_index[..., None, :, None, None]
)
if asym_id is not None:
in_one_chain = (
asym_id[..., :, None, None, None] == asym_id[..., None, :, None, None]
)
diagonal = diagonal & in_one_chain
dists_mask = dists_mask * (1.0 - (diagonal).float())
c_one_hot = one_hot(residue_index.new_tensor(2), num_classes=14)
c_one_hot = c_one_hot.reshape(
*((1,) * len(residue_index.shape[:-1])), *c_one_hot.shape
)
c_one_hot = c_one_hot.type(fp_type)
n_one_hot = one_hot(residue_index.new_tensor(0), num_classes=14)
n_one_hot = n_one_hot.reshape(
*((1,) * len(residue_index.shape[:-1])), *n_one_hot.shape
)
n_one_hot = n_one_hot.type(fp_type)
neighbour_mask = (residue_index[..., :, None] + 1) == residue_index[..., None, :]
if asym_id is not None:
neighbour_mask &= asym_id[..., :, None] == asym_id[..., None, :]
neighbour_mask = neighbour_mask[..., None, None].float()
c_n_bonds = (
neighbour_mask
* c_one_hot[..., None, None, :, None]
* n_one_hot[..., None, None, None, :]
)
dists_mask = dists_mask * (1.0 - c_n_bonds)
cys = rc.restype_name_to_atom14_names["CYS"]
cys_sg_idx = cys.index("SG")
cys_sg_idx = residue_index.new_tensor(cys_sg_idx)
cys_sg_idx = cys_sg_idx.reshape(*((1,) * len(residue_index.shape[:-1])), 1).squeeze(
-1
)
cys_sg_one_hot = one_hot(cys_sg_idx, num_classes=14)
disulfide_bonds = (
cys_sg_one_hot[..., None, None, :, None]
* cys_sg_one_hot[..., None, None, None, :]
)
dists_mask = dists_mask * (1.0 - disulfide_bonds)
dists_lower_bound = dists_mask * (
atom14_atom_radius[..., :, None, :, None].float()
+ atom14_atom_radius[..., None, :, None, :].float()
)
dists_to_low_error = dists_mask * torch.nn.functional.relu(
dists_lower_bound - overlap_tolerance_soft - dists
)
mean_loss = torch.sum(dists_to_low_error) / (1e-6 + torch.sum(dists_mask))
per_atom_loss_sum = torch.sum(dists_to_low_error, dim=(-4, -2)) + torch.sum(
dists_to_low_error, dim=(-3, -1)
)
clash_mask = (
dists_mask * (dists < (dists_lower_bound - overlap_tolerance_hard)).float()
)
per_atom_clash_mask = torch.maximum(
torch.amax(clash_mask, dim=(-4, -2)),
torch.amax(clash_mask, dim=(-3, -1)),
)
per_atom_clash_count = torch.sum(clash_mask, dim=(-4, -2)) + torch.sum(
clash_mask, dim=(-3, -1)
)
return {
"mean_loss": mean_loss,
"per_atom_loss_sum": per_atom_loss_sum,
"per_atom_clash_mask": per_atom_clash_mask,
"per_atom_clash_count": per_atom_clash_count,
}
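# Within-residue geometry violations: penalizes atom14 pairwise distances that
# fall outside the residue-specific lower/upper bounds.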
def within_residue_violations(
atom14_pred_positions: torch.Tensor,
atom14_atom_exists: torch.Tensor,
atom14_dists_lower_bound: torch.Tensor,
atom14_dists_upper_bound: torch.Tensor,
tighten_bounds_for_loss=0.0,
) -> Dict[str, torch.Tensor]:
atom14_pred_positions = atom14_pred_positions.float()
atom14_atom_exists = atom14_atom_exists.float()
dists_masks = 1.0 - torch.eye(14, device=atom14_atom_exists.device)[None]
dists_masks = dists_masks.reshape(
*((1,) * len(atom14_atom_exists.shape[:-2])), *dists_masks.shape
)
dists_masks = (
atom14_atom_exists[..., :, :, None]
* atom14_atom_exists[..., :, None, :]
* dists_masks
)
dists = torch.sqrt(
1e-10
+ torch.sum(
(
atom14_pred_positions[..., :, :, None, :]
- atom14_pred_positions[..., :, None, :, :]
)
** 2,
dim=-1,
)
)
dists_to_low_error = torch.nn.functional.relu(
atom14_dists_lower_bound + tighten_bounds_for_loss - dists
)
dists_to_high_error = torch.nn.functional.relu(
dists - (atom14_dists_upper_bound - tighten_bounds_for_loss)
)
loss = dists_masks * (dists_to_low_error + dists_to_high_error)
per_atom_loss_sum = torch.sum(loss, dim=-2) + torch.sum(loss, dim=-1)
violations = (
dists_masks
* (
(dists < atom14_dists_lower_bound) | (dists > atom14_dists_upper_bound)
).float()
)
per_atom_violations = torch.maximum(
torch.max(violations, dim=-2)[0], torch.max(violations, dim=-1)[0]
)
per_atom_clash_count = torch.sum(violations, dim=-2) + torch.sum(violations, dim=-1)
return {
"per_atom_loss_sum": per_atom_loss_sum,
"per_atom_violations": per_atom_violations,
"per_atom_clash_count": per_atom_clash_count,
}
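# Aggregates the three structural checks (bond/angle violations, between-residue
# clashes, within-residue bound violations) into a nested dict, together with a
# per-residue violation mask.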
def find_structural_violations(
batch: Dict[str, torch.Tensor],
atom14_pred_positions: torch.Tensor,
violation_tolerance_factor: float,
clash_overlap_tolerance: float,
**kwargs,
) -> Dict[str, torch.Tensor]:
atom14_pred_positions = atom14_pred_positions.float()
connection_violations = between_residue_bond_loss(
pred_atom_positions=atom14_pred_positions,
pred_atom_mask=batch["atom14_atom_exists"],
residue_index=batch["residue_index"],
aatype=batch["aatype"],
asym_id=batch["asym_id"] if "asym_id" in batch else None,
tolerance_factor_soft=violation_tolerance_factor,
tolerance_factor_hard=violation_tolerance_factor,
)
atomtype_radius = [rc.van_der_waals_radius[name[0]] for name in rc.atom_types]
atomtype_radius = atom14_pred_positions.new_tensor(atomtype_radius)
atom14_atom_radius = (
batch["atom14_atom_exists"] * atomtype_radius[batch["residx_atom14_to_atom37"]]
)
between_residue_clashes = between_residue_clash_loss(
atom14_pred_positions=atom14_pred_positions,
atom14_atom_exists=batch["atom14_atom_exists"],
atom14_atom_radius=atom14_atom_radius,
residue_index=batch["residue_index"],
asym_id=batch["asym_id"] if "asym_id" in batch else None,
overlap_tolerance_soft=clash_overlap_tolerance,
overlap_tolerance_hard=clash_overlap_tolerance,
)
restype_atom14_bounds = rc.make_atom14_dists_bounds(
overlap_tolerance=clash_overlap_tolerance,
bond_length_tolerance_factor=violation_tolerance_factor,
)
atom14_atom_exists = batch["atom14_atom_exists"]
atom14_dists_lower_bound = atom14_pred_positions.new_tensor(
restype_atom14_bounds["lower_bound"]
)[batch["aatype"]]
atom14_dists_upper_bound = atom14_pred_positions.new_tensor(
restype_atom14_bounds["upper_bound"]
)[batch["aatype"]]
residue_violations = within_residue_violations(
atom14_pred_positions=atom14_pred_positions,
atom14_atom_exists=batch["atom14_atom_exists"],
atom14_dists_lower_bound=atom14_dists_lower_bound,
atom14_dists_upper_bound=atom14_dists_upper_bound,
tighten_bounds_for_loss=0.0,
)
per_residue_violations_mask = torch.max(
torch.stack(
[
connection_violations["per_residue_violation_mask"],
torch.max(between_residue_clashes["per_atom_clash_mask"], dim=-1)[0],
torch.max(residue_violations["per_atom_violations"], dim=-1)[0],
],
dim=-1,
),
dim=-1,
)[0]
return {
"between_residues": {
"bonds_c_n_loss_mean": connection_violations["c_n_loss_mean"],
"angles_ca_c_n_loss_mean": connection_violations["ca_c_n_loss_mean"],
"angles_c_n_ca_loss_mean": connection_violations["c_n_ca_loss_mean"],
"connections_per_residue_loss_sum": connection_violations[
"per_residue_loss_sum"
],
"connections_per_residue_violation_mask": connection_violations[
"per_residue_violation_mask"
],
"clashes_mean_loss": between_residue_clashes["mean_loss"],
"clashes_per_atom_loss_sum": between_residue_clashes["per_atom_loss_sum"],
"clashes_per_atom_clash_mask": between_residue_clashes[
"per_atom_clash_mask"
],
"clashes_per_atom_clash_count": between_residue_clashes[
"per_atom_clash_count"
],
},
"within_residues": {
"per_atom_loss_sum": residue_violations["per_atom_loss_sum"],
"per_atom_violations": residue_violations["per_atom_violations"],
"per_atom_clash_count": residue_violations["per_atom_clash_count"],
},
"total_per_residue_violations_mask": per_residue_violations_mask,
}
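# Fraction of consecutive residue pairs (no residue-index gap) whose Ca-Ca
# distance exceeds the ideal value by more than the given tolerance.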
def extreme_ca_ca_distance_violations(
pred_atom_positions: torch.Tensor,
pred_atom_mask: torch.Tensor,
residue_index: torch.Tensor,
max_angstrom_tolerance=1.5,
eps=1e-6,
) -> torch.Tensor:
pred_atom_positions = pred_atom_positions.float()
this_ca_pos = pred_atom_positions[..., :-1, 1, :]
this_ca_mask = pred_atom_mask[..., :-1, 1]
next_ca_pos = pred_atom_positions[..., 1:, 1, :]
next_ca_mask = pred_atom_mask[..., 1:, 1]
has_no_gap_mask = (
(residue_index[..., 1:] - residue_index[..., :-1]) == 1.0
).float()
ca_ca_distance = torch.sqrt(
eps + torch.sum((this_ca_pos - next_ca_pos) ** 2, dim=-1)
)
violations = (ca_ca_distance - rc.ca_ca) > max_angstrom_tolerance
mask = this_ca_mask * next_ca_mask * has_no_gap_mask
mean = masked_mean(mask, violations, -1)
return mean
def compute_violation_metrics(
batch: Dict[str, torch.Tensor],
atom14_pred_positions: torch.Tensor,
violations: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
"""Compute several metrics to assess the structural violations."""
atom14_pred_positions = atom14_pred_positions.float()
ret = {}
extreme_ca_ca_violations = extreme_ca_ca_distance_violations(
pred_atom_positions=atom14_pred_positions,
pred_atom_mask=batch["atom14_atom_exists"],
residue_index=batch["residue_index"],
)
ret["violations_extreme_ca_ca_distance"] = extreme_ca_ca_violations
ret["violations_between_residue_bond"] = masked_mean(
batch["seq_mask"],
violations["between_residues"]["connections_per_residue_violation_mask"],
dim=-1,
)
ret["violations_between_residue_clash"] = masked_mean(
mask=batch["seq_mask"],
value=torch.max(
violations["between_residues"]["clashes_per_atom_clash_mask"],
dim=-1,
)[0],
dim=-1,
)
ret["violations_within_residue"] = masked_mean(
mask=batch["seq_mask"],
value=torch.max(violations["within_residues"]["per_atom_violations"], dim=-1)[
0
],
dim=-1,
)
ret["violations_per_residue"] = masked_mean(
mask=batch["seq_mask"],
value=violations["total_per_residue_violations_mask"],
dim=-1,
)
return ret
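# Total violation loss: mean C-N bond-length violation, plus the two bond-angle
# violations weighted by bond_angle_loss_weight, plus the average per-atom
# clash loss (normalized by the clash count).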
def violation_loss(
violations: Dict[str, torch.Tensor],
eps=1e-6,
loss_dict=None,
bond_angle_loss_weight: float = 0.3,
**kwargs,
) -> torch.Tensor:
l_clash = torch.sum(
violations["between_residues"]["clashes_per_atom_loss_sum"]
+ violations["within_residues"]["per_atom_loss_sum"],
dim=(-1, -2),
)
cnt_clash = torch.sum(
violations["between_residues"]["clashes_per_atom_clash_count"]
+ violations["within_residues"]["per_atom_clash_count"],
dim=(-1, -2),
)
l_clash = l_clash / (eps + cnt_clash)
loss = (
violations["between_residues"]["bonds_c_n_loss_mean"]
+ bond_angle_loss_weight
* violations["between_residues"]["angles_ca_c_n_loss_mean"]
+ bond_angle_loss_weight
* violations["between_residues"]["angles_c_n_ca_loss_mean"]
+ l_clash
)
loss_dict["violation"] = loss.data
return loss
import logging
from typing import Any
from unicore.models import BaseUnicoreModel, register_model, register_model_architecture
from unifold.modules.alphafold import AlphaFold
from unifold.config import model_config
logger = logging.getLogger(__name__)
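# Unicore model wrapper: registers the Uni-Fold AlphaFold module under the
# "af2" architecture name, builds it from the named model config, and returns
# both the network outputs and the loss config from forward().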
@register_model("af2")
class AlphafoldModel(BaseUnicoreModel):
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
parser.add_argument(
"--model-name",
help="choose the model config",
)
def __init__(self, args):
super().__init__()
base_architecture(args)
self.args = args
config = model_config(
self.args.model_name,
train=True,
)
self.model = AlphaFold(config)
self.config = config
def half(self):
self.model = self.model.half()
return self
def bfloat16(self):
self.model = self.model.bfloat16()
return self
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
return cls(args)
def forward(self, batch, **kwargs):
outputs = self.model.forward(batch)
return outputs, self.config.loss
@register_model_architecture("af2", "af2")
def base_architecture(args):
args.model_name = getattr(args, "model_name", "model_2")
"""Modules of Uni-Fold models."""
from unicore.utils import (
set_jit_fusion_options,
)
set_jit_fusion_options()