Commit 51556d52 authored by Christina Floristean

Added multimer changes for loss functions

parent fbfbd808
......@@ -648,6 +648,7 @@ config = mlc.ConfigDict(
"violation": {
"violation_tolerance_factor": 12.0,
"clash_overlap_tolerance": 1.5,
"average_clashes": False,
"eps": eps, # 1e-6,
"weight": 0.0,
},
......@@ -660,6 +661,12 @@ config = mlc.ConfigDict(
"weight": 0.,
"enabled": tm_enabled,
},
"chain_center_of_mass": {
"clamp_distance": -4.0,
"weight": 0.,
"eps": eps,
"enabled": False,
},
"eps": eps,
},
"ema": {"decay": 0.999},
......@@ -802,7 +809,9 @@ multimer_model_config_update = {
"tm": {
"c_z": c_z,
"no_bins": aux_distogram_bins,
"enabled": tm_enabled,
"ptm_weight": 0.2,
"iptm_weight": 0.8,
"enabled": True,
},
"masked_msa": {
"c_m": c_m,
......@@ -813,5 +822,81 @@ multimer_model_config_update = {
"c_out": 37,
},
},
"loss": {
"distogram": {
"min_bin": 2.3125,
"max_bin": 21.6875,
"no_bins": 64,
"eps": eps, # 1e-6,
"weight": 0.3,
},
"experimentally_resolved": {
"eps": eps, # 1e-8,
"min_resolution": 0.1,
"max_resolution": 3.0,
"weight": 0.0,
},
"fape": {
"intra_chain_backbone": {
"clamp_distance": 10.0,
"loss_unit_distance": 10.0,
"weight": 0.5,
},
"interface": {
"clamp_distance": 30.0,
"loss_unit_distance": 20.0,
"weight": 0.5,
},
"sidechain": {
"clamp_distance": 10.0,
"length_scale": 10.0,
"weight": 0.5,
},
"eps": 1e-4,
"weight": 1.0,
},
"plddt_loss": {
"min_resolution": 0.1,
"max_resolution": 3.0,
"cutoff": 15.0,
"no_bins": 50,
"eps": eps, # 1e-10,
"weight": 0.01,
},
"masked_msa": {
"num_classes": 23,
"eps": eps, # 1e-8,
"weight": 2.0,
},
"supervised_chi": {
"chi_weight": 0.5,
"angle_norm_weight": 0.01,
"eps": eps, # 1e-6,
"weight": 1.0,
},
"violation": {
"violation_tolerance_factor": 12.0,
"clash_overlap_tolerance": 1.5,
"average_clashes": True,
"eps": eps, # 1e-6,
"weight": 0.03, # Not finetuning
},
"tm": {
"max_bin": 31,
"no_bins": 64,
"min_resolution": 0.1,
"max_resolution": 3.0,
"eps": eps, # 1e-8,
"weight": 0.1,
"enabled": True,
},
"chain_center_of_mass": {
"clamp_distance": -4.0,
"weight": 0.05,
"eps": eps,
"enabled": True,
},
"eps": eps,
},
"recycle_early_stop_tolerance": 0.5
}
......@@ -76,9 +76,17 @@ class AuxiliaryHeads(nn.Module):
if self.config.tm.enabled:
tm_logits = self.tm(outputs["pair"])
aux_out["tm_logits"] = tm_logits
aux_out["predicted_tm_score"] = compute_tm(
aux_out["ptm_score"] = compute_tm(
tm_logits, **self.config.tm
)
asym_id = outputs.get("asym_id")
if asym_id is not None:
aux_out["iptm_score"] = compute_tm(
tm_logits, asym_id=asym_id, interface=True, **self.config.tm
)
aux_out["weighted_ptm_score"] = (self.config.tm["iptm_weight"] * aux_out["iptm_score"]
+ self.config.tm["ptm_weight"] * aux_out["ptm_score"])
aux_out.update(
compute_predicted_aligned_error(
tm_logits,
......
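The new weighted_ptm_score combines the two confidence heads the same way AF-Multimer ranks models: 0.8 * ipTM + 0.2 * pTM. A toy check of the arithmetic (hypothetical scores, not produced by this code):

iptm_weight, ptm_weight = 0.8, 0.2
iptm_score, ptm_score = 0.75, 0.90  # hypothetical head outputs
weighted_ptm_score = iptm_weight * iptm_score + ptm_weight * ptm_score
assert abs(weighted_ptm_score - 0.78) < 1e-9  # 0.8 * 0.75 + 0.2 * 0.90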
......@@ -555,6 +555,9 @@ class AlphaFold(nn.Module):
else:
break
if "asym_id" in batch:
outputs["asym_id"] = feats["asym_id"]
# Run auxiliary heads
outputs.update(self.aux_heads(outputs))
......
......@@ -435,7 +435,7 @@ class TriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
# reduced-precision modes
a_std = a.std()
b_std = b.std()
-        if(a_std != 0. and b_std != 0.):
+        if(is_fp16_enabled() and a_std != 0. and b_std != 0.):
a = a / a.std()
b = b / b.std()
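The normalization is now gated on is_fp16_enabled() because half precision saturates above ~65504, so the matmul of the combine projections can overflow to inf; the std checks additionally avoid dividing by zero on constant inputs. A standalone sketch of the failure mode and the guard (assumed shapes and scales, not the module's own code):

import torch

print(torch.tensor(70000.0).half())  # tensor(inf, dtype=torch.float16): exceeds the fp16 max

a = torch.randn(8, 8) * 3000.  # large-scale activations
a_std = a.std()
if a_std != 0.:  # skip constant tensors so we never divide by zero
    a = a / a_std  # near unit scale, safe to multiply in fp16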
......@@ -589,8 +589,11 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
# Prevents overflow of torch.matmul in combine projections in
# reduced-precision modes
-        a = a / a.std()
-        b = b / b.std()
+        a_std = a.std()
+        b_std = b.std()
+        if (is_fp16_enabled() and a_std != 0. and b_std != 0.):
+            a = a / a.std()
+            b = b / b.std()
if (is_fp16_enabled()):
with torch.cuda.amp.autocast(enabled=False):
......
......@@ -193,7 +193,7 @@ def square_euclidean_distance(
difference = vec1 - vec2
distance = difference.dot(difference)
if epsilon:
-        distance = torch.maximum(distance, epsilon)
+        distance = torch.clamp(distance, min=epsilon)
return distance
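The switch from torch.maximum to torch.clamp matters because torch.maximum requires its second argument to be a tensor, so passing the Python float epsilon raises a TypeError, whereas torch.clamp accepts a plain number for min. A minimal illustration:

import torch

distance = torch.tensor(1e-12)
epsilon = 1e-6
# torch.maximum(distance, epsilon)  # TypeError: argument 'other' must be a Tensor
print(torch.clamp(distance, min=epsilon))  # tensor(1.0000e-06)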
......
......@@ -617,7 +617,7 @@ def generate_translation_dict(model, version, is_multimer=False):
translations["evoformer"].update(template_param_dict)
if "_ptm" in version:
if is_multimer or "_ptm" in version:
translations["predicted_aligned_error_head"] = {
"logits": LinearParams(model.aux_heads.tm.linear)
}
......
......@@ -25,6 +25,8 @@ from typing import Dict, Optional, Tuple
from openfold.np import residue_constants
from openfold.utils import feats
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.geometry.vector import Vec3Array, euclidean_distance
from openfold.utils.all_atom_multimer import get_rc_tensor
from openfold.utils.tensor_utils import (
tree_map,
tensor_tree_map,
......@@ -87,6 +89,7 @@ def compute_fape(
target_positions: torch.Tensor,
positions_mask: torch.Tensor,
length_scale: float,
pair_mask: Optional[torch.Tensor] = None,
l1_clamp_distance: Optional[float] = None,
eps=1e-8,
) -> torch.Tensor:
......@@ -108,6 +111,9 @@ def compute_fape(
[*, N_pts] positions mask
length_scale:
Length scale by which the loss is divided
pair_mask:
[*, N_frames, N_pts] mask to use for
separating intra- from inter-chain losses.
l1_clamp_distance:
Cutoff above which distance errors are disregarded
eps:
......@@ -134,21 +140,30 @@ def compute_fape(
normed_error = normed_error * frames_mask[..., None]
normed_error = normed_error * positions_mask[..., None, :]
-    # FP16-friendly averaging. Roughly equivalent to:
-    #
-    # norm_factor = (
-    #     torch.sum(frames_mask, dim=-1) *
-    #     torch.sum(positions_mask, dim=-1)
-    # )
-    # normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
-    #
-    # ("roughly" because eps is necessarily duplicated in the latter)
-    normed_error = torch.sum(normed_error, dim=-1)
-    normed_error = (
-        normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
-    )
-    normed_error = torch.sum(normed_error, dim=-1)
-    normed_error = normed_error / (eps + torch.sum(positions_mask, dim=-1))
+    if pair_mask is not None:
+        normed_error = normed_error * pair_mask
+        normed_error = torch.sum(normed_error, dim=(-1, -2))
+        mask = frames_mask[..., None] * positions_mask[..., None, :] * pair_mask
+        norm_factor = torch.sum(mask, dim=(-2, -1))
+        normed_error = normed_error / (eps + norm_factor)
+    else:
+        # FP16-friendly averaging. Roughly equivalent to:
+        #
+        # norm_factor = (
+        #     torch.sum(frames_mask, dim=-1) *
+        #     torch.sum(positions_mask, dim=-1)
+        # )
+        # normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
+        #
+        # ("roughly" because eps is necessarily duplicated in the latter)
+        normed_error = torch.sum(normed_error, dim=-1)
+        normed_error = (
+            normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
+        )
+        normed_error = torch.sum(normed_error, dim=-1)
+        normed_error = normed_error / (eps + torch.sum(positions_mask, dim=-1))
return normed_error
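With the new pair_mask argument, the same compute_fape call can be restricted either to intra-chain or to inter-chain frame/atom pairs, which is how fape_loss below splits the backbone term. A sketch of the mask construction on a toy asym_id (standalone illustration):

import torch

asym_id = torch.tensor([0, 0, 1, 1])  # two chains, two residues each
intra_chain_mask = (asym_id[..., None] == asym_id[..., None, :]).float()
interface_mask = 1. - intra_chain_mask
print(intra_chain_mask)  # block-diagonal ones: pairs within a chain
print(interface_mask)    # off-diagonal blocks: pairs across chains
# compute_fape(..., pair_mask=intra_chain_mask) -> intra-chain FAPE
# compute_fape(..., pair_mask=interface_mask)   -> interface FAPE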
......@@ -157,6 +172,7 @@ def backbone_loss(
backbone_rigid_tensor: torch.Tensor,
backbone_rigid_mask: torch.Tensor,
traj: torch.Tensor,
pair_mask: Optional[torch.Tensor] = None,
use_clamped_fape: Optional[torch.Tensor] = None,
clamp_distance: float = 10.0,
loss_unit_distance: float = 10.0,
......@@ -184,6 +200,7 @@ def backbone_loss(
pred_aff.get_trans(),
gt_aff[None].get_trans(),
backbone_rigid_mask[None],
pair_mask=pair_mask,
l1_clamp_distance=clamp_distance,
length_scale=loss_unit_distance,
eps=eps,
......@@ -196,6 +213,7 @@ def backbone_loss(
pred_aff.get_trans(),
gt_aff[None].get_trans(),
backbone_rigid_mask[None],
pair_mask=pair_mask,
l1_clamp_distance=None,
length_scale=loss_unit_distance,
eps=eps,
......@@ -253,6 +271,7 @@ def sidechain_loss(
sidechain_atom_pos,
renamed_atom14_gt_positions,
renamed_atom14_gt_exists,
pair_mask=None,
l1_clamp_distance=clamp_distance,
length_scale=length_scale,
eps=eps,
......@@ -266,10 +285,29 @@ def fape_loss(
batch: Dict[str, torch.Tensor],
config: ml_collections.ConfigDict,
) -> torch.Tensor:
-    bb_loss = backbone_loss(
-        traj=out["sm"]["frames"],
-        **{**batch, **config.backbone},
-    )
+    traj = out["sm"]["frames"]
+    asym_id = batch.get("asym_id")
+    if asym_id is not None:
+        intra_chain_mask = (asym_id[..., None] == asym_id[..., None, :]).to(dtype=traj.dtype)
+        intra_chain_bb_loss = backbone_loss(
+            traj=traj,
+            pair_mask=intra_chain_mask,
+            **{**batch, **config.intra_chain_backbone},
+        )
+        interface_bb_loss = backbone_loss(
+            traj=traj,
+            pair_mask=1. - intra_chain_mask,
+            **{**batch, **config.interface_backbone},
+        )
+        weighted_bb_loss = (intra_chain_bb_loss * config.intra_chain_backbone.weight
+                            + interface_bb_loss * config.interface_backbone.weight)
+    else:
+        bb_loss = backbone_loss(
+            traj=traj,
+            **{**batch, **config.backbone},
+        )
+        weighted_bb_loss = bb_loss * config.backbone.weight
sc_loss = sidechain_loss(
out["sm"]["sidechain_frames"],
......@@ -277,7 +315,7 @@ def fape_loss(
**{**batch, **config.sidechain},
)
-    loss = config.backbone.weight * bb_loss + config.sidechain.weight * sc_loss
+    loss = weighted_bb_loss + config.sidechain.weight * sc_loss
# Average over the batch dimension
loss = torch.mean(loss)
......@@ -654,7 +692,7 @@ def compute_tm(
n = residue_weights.shape[-1]
pair_mask = residue_weights.new_ones((n, n), dtype=torch.int32)
if interface:
-        pair_mask *= (asym_id[..., None] != asym_id[..., None, :])
+        pair_mask *= (asym_id[..., None] != asym_id[..., None, :]).to(dtype=pair_mask.dtype)
predicted_tm_term *= pair_mask
......@@ -891,6 +929,7 @@ def between_residue_clash_loss(
atom14_atom_exists: torch.Tensor,
atom14_atom_radius: torch.Tensor,
residue_index: torch.Tensor,
asym_id: Optional[torch.Tensor] = None,
overlap_tolerance_soft=1.5,
overlap_tolerance_hard=1.5,
eps=1e-10,
......@@ -966,9 +1005,13 @@ def between_residue_clash_loss(
)
n_one_hot = n_one_hot.type(fp_type)
-    neighbour_mask = (
-        residue_index[..., :, None, None, None] + 1
-    ) == residue_index[..., None, :, None, None]
+    neighbour_mask = (residue_index[..., :, None] + 1) == residue_index[..., None, :]
+    if asym_id is not None:
+        neighbour_mask = neighbour_mask & (asym_id[..., :, None] == asym_id[..., None, :])
+    neighbour_mask = neighbour_mask[..., None, None]
c_n_bonds = (
neighbour_mask
* c_one_hot[..., None, None, :, None]
......@@ -1010,26 +1053,29 @@ def between_residue_clash_loss(
# Compute the per atom loss sum.
# shape (N, 14)
    per_atom_loss_sum = torch.sum(dists_to_low_error, dim=(-4, -2)) + torch.sum(
-        dists_to_low_error, axis=(-3, -1)
+        dists_to_low_error, dim=(-3, -1)
    )
# Compute the hard clash mask.
# shape (N, N, 14, 14)
    clash_mask = dists_mask * (
        dists < (dists_lower_bound - overlap_tolerance_hard)
    )
per_atom_num_clash = torch.sum(clash_mask, dim=(-4, -2)) + torch.sum(clash_mask, dim=(-3, -1))
# Compute the per atom clash.
# shape (N, 14)
per_atom_clash_mask = torch.maximum(
-        torch.amax(clash_mask, axis=(-4, -2)),
-        torch.amax(clash_mask, axis=(-3, -1)),
+        torch.amax(clash_mask, dim=(-4, -2)),
+        torch.amax(clash_mask, dim=(-3, -1)),
)
return {
"mean_loss": mean_loss, # shape ()
"per_atom_loss_sum": per_atom_loss_sum, # shape (N, 14)
"per_atom_clash_mask": per_atom_clash_mask, # shape (N, 14)
"per_atom_num_clash": per_atom_num_clash # shape (N, 14)
}
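Restricting neighbour_mask to pairs with matching asym_id means the C-N peptide-bond exemption no longer fires for consecutive residue indices that straddle a chain boundary. A toy demonstration, assuming residue numbering runs continuously across the break (illustrative values):

import torch

residue_index = torch.tensor([0, 1, 2, 3, 4])
asym_id = torch.tensor([0, 0, 0, 1, 1])  # chain break between residues 2 and 3
neighbour = (residue_index[:, None] + 1) == residue_index[None, :]
print(neighbour.nonzero().tolist())  # [[0, 1], [1, 2], [2, 3], [3, 4]]
neighbour = neighbour & (asym_id[:, None] == asym_id[None, :])
print(neighbour.nonzero().tolist())  # [[0, 1], [1, 2], [3, 4]]: (2, 3) is treated as a potential clash again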
......@@ -1109,6 +1155,8 @@ def within_residue_violations(
(dists < atom14_dists_lower_bound) | (dists > atom14_dists_upper_bound)
)
per_atom_num_clash = torch.sum(violations, dim=-2) + torch.sum(violations, dim=-1)
# Compute the per atom violations.
per_atom_violations = torch.maximum(
torch.max(violations, dim=-2)[0], torch.max(violations, axis=-1)[0]
......@@ -1117,6 +1165,7 @@ def within_residue_violations(
return {
"per_atom_loss_sum": per_atom_loss_sum,
"per_atom_violations": per_atom_violations,
"per_atom_num_clash": per_atom_num_clash
}
......@@ -1146,11 +1195,24 @@ def find_structural_violations(
residue_constants.van_der_waals_radius[name[0]]
for name in residue_constants.atom_types
]
atomtype_radius = atom14_pred_positions.new_tensor(atomtype_radius)
-    atom14_atom_radius = (
-        batch["atom14_atom_exists"]
-        * atomtype_radius[batch["residx_atom14_to_atom37"]]
-    )
+    # TODO: Consolidate monomer/multimer modes
+    asym_id = batch.get("asym_id")
+    if asym_id is not None:
+        residx_atom14_to_atom37 = get_rc_tensor(
+            residue_constants.RESTYPE_ATOM14_TO_ATOM37, batch["aatype"]
+        )
+        atom14_atom_radius = (
+            batch["atom14_atom_exists"]
+            * atomtype_radius[residx_atom14_to_atom37.long()]
+        )
+    else:
+        atom14_atom_radius = (
+            batch["atom14_atom_exists"]
+            * atomtype_radius[batch["residx_atom14_to_atom37"]]
+        )
# Compute the between residue clash loss.
between_residue_clashes = between_residue_clash_loss(
......@@ -1158,6 +1220,7 @@ def find_structural_violations(
atom14_atom_exists=batch["atom14_atom_exists"],
atom14_atom_radius=atom14_atom_radius,
residue_index=batch["residue_index"],
asym_id=asym_id,
overlap_tolerance_soft=clash_overlap_tolerance,
overlap_tolerance_hard=clash_overlap_tolerance,
)
......@@ -1220,6 +1283,9 @@ def find_structural_violations(
"clashes_per_atom_clash_mask": between_residue_clashes[
"per_atom_clash_mask"
], # (N, 14)
"clashes_per_atom_num_clash": between_residue_clashes[
"per_atom_num_clash"
], # (N, 14)
},
"within_residues": {
"per_atom_loss_sum": residue_violations[
......@@ -1228,6 +1294,9 @@ def find_structural_violations(
"per_atom_violations": residue_violations[
"per_atom_violations"
], # (N, 14),
"per_atom_num_clash": residue_violations[
"per_atom_num_clash"
], # (N, 14)
},
"total_per_residue_violations_mask": per_residue_violations_mask, # (N)
}
......@@ -1349,15 +1418,21 @@ def compute_violation_metrics_np(
def violation_loss(
violations: Dict[str, torch.Tensor],
atom14_atom_exists: torch.Tensor,
average_clashes: bool = False,
eps=1e-6,
**kwargs,
) -> torch.Tensor:
num_atoms = torch.sum(atom14_atom_exists)
-    l_clash = torch.sum(
-        violations["between_residues"]["clashes_per_atom_loss_sum"]
-        + violations["within_residues"]["per_atom_loss_sum"]
-    )
-    l_clash = l_clash / (eps + num_atoms)
+    per_atom_clash = (violations["between_residues"]["clashes_per_atom_loss_sum"] +
+                      violations["within_residues"]["per_atom_loss_sum"])
+    if average_clashes:
+        num_clash = (violations["between_residues"]["clashes_per_atom_num_clash"] +
+                     violations["within_residues"]["per_atom_num_clash"])
+        per_atom_clash = per_atom_clash / (num_clash + eps)
+    l_clash = torch.sum(per_atom_clash) / (eps + num_atoms)
loss = (
violations["between_residues"]["bonds_c_n_loss_mean"]
+ violations["between_residues"]["angles_ca_c_n_loss_mean"]
......@@ -1533,6 +1608,64 @@ def masked_msa_loss(logits, true_msa, bert_mask, num_classes, eps=1e-8, **kwargs
return loss
def chain_center_of_mass_loss(
all_atom_pred_pos: torch.Tensor,
all_atom_positions: torch.Tensor,
all_atom_mask: torch.Tensor,
asym_id: torch.Tensor,
clamp_distance: float = -4.0,
weight: float = 0.05,
eps: float = 1e-10
) -> torch.Tensor:
"""
Computes chain center-of-mass loss. Implements section 2.5, Eqn. 1 of the Multimer paper.
Args:
all_atom_pred_pos:
[*, N_pts, 37, 3] All-atom predicted atom positions
all_atom_positions:
[*, N_pts, 37, 3] Ground truth all-atom positions
all_atom_mask:
[*, N_pts, 37] All-atom positions mask
asym_id:
[*, N_pts] Chain asym IDs
clamp_distance:
Cutoff above which distance errors are disregarded
weight:
Weight for loss
eps:
Small value used to regularize denominators
Returns:
[*] loss tensor
"""
ca_pos = residue_constants.atom_order["CA"]
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)] # keep dim
chains, _ = asym_id.unique(return_counts=True)
one_hot = torch.nn.functional.one_hot(asym_id, num_classes=chains.shape[0]).to(dtype=all_atom_mask.dtype)
one_hot = one_hot * all_atom_mask
chain_pos_mask = one_hot.transpose(-2, -1)
chain_exists = torch.any(chain_pos_mask, dim=-1).float()
def get_chain_center_of_mass(pos):
center_sum = (chain_pos_mask[..., None] * pos[..., None, :, :]).sum(dim=-2)
centers = center_sum / (torch.sum(chain_pos_mask, dim=-1, keepdim=True) + eps)
return Vec3Array.from_array(centers)
pred_centers = get_chain_center_of_mass(all_atom_pred_pos) # [B, NC, 3]
true_centers = get_chain_center_of_mass(all_atom_positions) # [B, NC, 3]
pred_dists = euclidean_distance(pred_centers[..., None, :], pred_centers[..., :, None], epsilon=eps)
true_dists = euclidean_distance(true_centers[..., None, :], true_centers[..., :, None], epsilon=eps)
losses = torch.clamp((weight * (pred_dists - true_dists - clamp_distance)), max=0) ** 2
loss_mask = chain_exists[..., :, None] * chain_exists[..., None, :]
loss = masked_mean(loss_mask, losses, dim=(-1, -2))
return loss
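A minimal smoke test of the new loss on random inputs (hypothetical shapes; note that the one_hot call above assumes asym_id values are contiguous integers starting at 0):

import torch

batch_size, n_res = 2, 8
pred = torch.rand(batch_size, n_res, 37, 3)
true = torch.rand(batch_size, n_res, 37, 3)
mask = torch.ones(batch_size, n_res, 37)
asym_id = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1]).expand(batch_size, -1)
loss = chain_center_of_mass_loss(pred, true, mask, asym_id, clamp_distance=-4.0, eps=1e-10)
print(loss.shape)  # torch.Size([2]): one value per batch element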
class AlphaFoldLoss(nn.Module):
"""Aggregation of the various losses described in the supplement"""
def __init__(self, config):
......@@ -1585,7 +1718,7 @@ class AlphaFoldLoss(nn.Module):
),
"violation": lambda: violation_loss(
out["violation"],
**batch,
**{**batch, **self.config.violation},
),
}
......@@ -1595,6 +1728,12 @@ class AlphaFoldLoss(nn.Module):
**{**batch, **out, **self.config.tm},
)
if (self.config.chain_center_of_mass.enabled):
loss_fns["chain_center_of_mass"] = lambda: chain_center_of_mass_loss(
all_atom_pred_pos=out["final_atom_positions"],
**{**batch, **self.config.chain_center_of_mass},
)
cum_loss = 0.
losses = {}
for loss_name, loss_fn in loss_fns.items():
......
......@@ -6,7 +6,7 @@ consts = mlc.ConfigDict(
"is_multimer": True, # monomer: False, multimer: True
"chunk_size": 4,
"batch_size": 2,
"n_res": 11,
"n_res": 22,
"n_seq": 13,
"n_templ": 3,
"n_extra": 17,
......
......@@ -29,14 +29,16 @@ def random_asym_ids(n_res, split_chains=True, min_chain_len=4):
pieces = []
asym_ids = []
final_idx = n_chain - 1
for idx in range(n_chain - 1):
n_stop = (n_res - sum(pieces) - n_chain + idx - min_chain_len)
if n_stop <= min_chain_len:
final_idx = idx
break
piece = randint(min_chain_len, n_stop)
pieces.append(piece)
asym_ids.extend(piece * [idx])
-    asym_ids.extend((n_res - sum(pieces)) * [n_chain - 1])
+    asym_ids.extend((n_res - sum(pieces)) * [final_idx])
return np.array(asym_ids).astype(np.int64)
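For reference, the fixed helper still returns one chain id per residue, but tracking final_idx keeps the chain ids contiguous when the splitting loop stops early. With the new n_res = 22 the output might look like the following (RNG-dependent; values shown are hypothetical):

ids = random_asym_ids(22)
print(ids)  # e.g. [0 0 0 0 0 0 0 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2]
assert ids.shape == (22,) and ids.dtype == np.int64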
......
......@@ -20,6 +20,7 @@ import unittest
import ml_collections as mlc
from openfold.data import data_transforms
from openfold.np import residue_constants
from openfold.utils.rigid_utils import (
Rotation,
Rigid,
......@@ -42,6 +43,8 @@ from openfold.utils.loss import (
sidechain_loss,
tm_loss,
compute_plddt,
compute_tm,
chain_center_of_mass_loss
)
from openfold.utils.tensor_utils import (
tree_map,
......@@ -233,11 +236,24 @@ class TestLoss(unittest.TestCase):
pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
atom_exists = np.random.randint(0, 2, (n_res, 14)).astype(np.float32)
-        atom_radius = np.random.rand(n_res, 14).astype(np.float32)
        res_ind = np.arange(n_res)
-        asym_id = random_asym_ids(n_res)
+        residx_atom14_to_atom37 = np.random.randint(0, 37, (n_res, 14)).astype(np.int64)
+        atomtype_radius = [
+            residue_constants.van_der_waals_radius[name[0]]
+            for name in residue_constants.atom_types
+        ]
+        atomtype_radius = np.array(atomtype_radius).astype(np.float32)
+        atom_radius = (
+            atom_exists
+            * atomtype_radius[residx_atom14_to_atom37]
+        )
+        asym_id = None
+        if consts.is_multimer:
+            asym_id = random_asym_ids(n_res)
out_gt = f.apply(
{},
......@@ -256,6 +272,7 @@ class TestLoss(unittest.TestCase):
torch.tensor(atom_exists).cuda(),
torch.tensor(atom_radius).cuda(),
torch.tensor(res_ind).cuda(),
torch.tensor(asym_id).cuda() if asym_id is not None else None,
)
out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)
......@@ -279,6 +296,36 @@ class TestLoss(unittest.TestCase):
torch.max(torch.abs(out_gt - out_repro)) < consts.eps
)
@compare_utils.skip_unless_alphafold_installed()
def test_compute_ptm_compare(self):
n_res = consts.n_res
max_bin = 31
no_bins = 64
logits = np.random.rand(n_res, n_res, no_bins)
boundaries = np.linspace(0, max_bin, num=(no_bins - 1))
ptm_gt = alphafold.common.confidence.predicted_tm_score(logits, boundaries)
ptm_gt = torch.tensor(ptm_gt)
logits_t = torch.tensor(logits)
ptm_repro = compute_tm(logits_t, no_bins=no_bins, max_bin=max_bin)
self.assertTrue(
torch.max(torch.abs(ptm_gt - ptm_repro)) < consts.eps
)
if consts.is_multimer:
asym_id = random_asym_ids(n_res)
iptm_gt = alphafold.common.confidence.predicted_tm_score(logits, boundaries,
asym_id=asym_id, interface=True)
iptm_gt = torch.tensor(iptm_gt)
iptm_repro = compute_tm(logits_t, no_bins=no_bins, max_bin=max_bin,
asym_id=torch.tensor(asym_id), interface=True)
self.assertTrue(
torch.max(torch.abs(iptm_gt - iptm_repro)) < consts.eps
)
def test_find_structural_violations(self):
n = consts.n_res
......@@ -335,9 +382,11 @@ class TestLoss(unittest.TestCase):
"residx_atom14_to_atom37": np.random.randint(
0, 37, (n_res, 14)
).astype(np.int64),
"asym_id": random_asym_ids(n_res)
}
if consts.is_multimer:
batch["asym_id"] = random_asym_ids(n_res)
pred_pos = np.random.rand(n_res, 14, 3)
config = mlc.ConfigDict(
......@@ -632,6 +681,40 @@ class TestLoss(unittest.TestCase):
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed()
def test_violation_loss(self):
config = compare_utils.get_alphafold_config()
c_viol = config.model.heads.structure_module
n_res = consts.n_res
batch = {
"seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
"residue_index": np.arange(n_res),
"aatype": np.random.randint(0, 21, (n_res,)),
}
if consts.is_multimer:
batch["asym_id"] = random_asym_ids(n_res)
batch = tree_map(lambda n: torch.tensor(n).cuda(), batch, np.ndarray)
atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
atom14_pred_pos = torch.tensor(atom14_pred_pos).cuda()
batch = data_transforms.make_atom14_masks(batch)
loss_sum_clash = violation_loss(
find_structural_violations(batch, atom14_pred_pos, **c_viol),
average_clashes=False, **batch
)
loss_sum_clash = loss_sum_clash.cpu()
loss_avg_clash = violation_loss(
find_structural_violations(batch, atom14_pred_pos, **c_viol),
average_clashes=True, **batch
)
loss_avg_clash = loss_avg_clash.cpu()
@compare_utils.skip_unless_alphafold_installed()
def test_violation_loss_compare(self):
config = compare_utils.get_alphafold_config()
......@@ -680,10 +763,12 @@ class TestLoss(unittest.TestCase):
batch = {
"seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
"residue_index": np.arange(n_res),
"aatype": np.random.randint(0, 21, (n_res,)),
"asym_id": random_asym_ids(n_res)
"aatype": np.random.randint(0, 21, (n_res,))
}
if consts.is_multimer:
batch["asym_id"] = random_asym_ids(n_res)
atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
alphafold.model.tf.data_transforms.make_atom14_masks(batch)
......@@ -801,8 +886,7 @@ class TestLoss(unittest.TestCase):
"backbone_affine_mask": np.random.randint(0, 2, (n_res,)).astype(
np.float32
),
"use_clamped_fape": np.array(0.0),
"asym_id": random_asym_ids(n_res)
"use_clamped_fape": np.array(0.0)
}
value = {
......@@ -814,6 +898,9 @@ class TestLoss(unittest.TestCase):
),
}
if consts.is_multimer:
batch["asym_id"] = random_asym_ids(n_res)
out_gt = f.apply({}, None, batch, value)
out_gt = torch.tensor(np.array(out_gt.block_until_ready()))
......@@ -826,8 +913,18 @@ class TestLoss(unittest.TestCase):
)
batch["backbone_rigid_mask"] = batch["backbone_affine_mask"]
out_repro = backbone_loss(traj=value["traj"], **{**batch, **c_sm})
out_repro = out_repro.cpu()
if consts.is_multimer:
intra_chain_mask = (batch["asym_id"][..., None]
== batch["asym_id"][..., None, :]).to(dtype=value["traj"].dtype)
intra_chain_out = backbone_loss(traj=value["traj"], pair_mask=intra_chain_mask,
**{**batch, **c_sm.intra_chain_fape})
interface_out = backbone_loss(traj=value["traj"], pair_mask=1. - intra_chain_mask,
**{**batch, **c_sm.interface_fape})
out_repro = intra_chain_out + interface_out
out_repro = out_repro.cpu()
else:
out_repro = backbone_loss(traj=value["traj"], **{**batch, **c_sm})
out_repro = out_repro.cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
......@@ -869,14 +966,14 @@ class TestLoss(unittest.TestCase):
v["sidechains"] = {}
v["sidechains"][
"frames"
-        ] = alphafold.model.r3.rigids_from_tensor4x4(
+        ] = self.am_rigid.rigids_from_tensor4x4(
value["sidechains"]["frames"]
)
v["sidechains"]["atom_pos"] = alphafold.model.r3.vecs_from_tensor(
v["sidechains"]["atom_pos"] = self.am_rigid.vecs_from_tensor(
value["sidechains"]["atom_pos"]
)
v.update(
-            alphafold.model.folding.compute_renamed_ground_truth(
+            self.am_fold.compute_renamed_ground_truth(
batch,
atom14_pred_positions,
)
......@@ -907,9 +1004,6 @@ class TestLoss(unittest.TestCase):
),
}
-        if consts.is_multimer:
-            batch["asym_id"] = random_asym_ids(n_res)
def _build_extra_feats_np():
b = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
b = data_transforms.make_atom14_masks(b)
......@@ -950,7 +1044,7 @@ class TestLoss(unittest.TestCase):
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed()
-    @unittest.skipIf(consts.is_multimer or "ptm" not in consts.model, "Not enabled for non-ptm models.")
+    @unittest.skipIf(not consts.is_multimer and "ptm" not in consts.model, "Not enabled for non-ptm models.")
def test_tm_loss_compare(self):
config = compare_utils.get_alphafold_config()
c_tm = config.model.heads.predicted_aligned_error
......@@ -1017,6 +1111,33 @@ class TestLoss(unittest.TestCase):
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed()
def test_chain_center_of_mass_loss(self):
batch_size = consts.batch_size
n_res = consts.n_res
batch = {
"all_atom_positions": np.random.rand(batch_size, n_res, 37, 3).astype(np.float32) * 10.0,
"all_atom_mask": np.random.randint(0, 2, (batch_size, n_res, 37)).astype(np.float32),
"asym_id": np.stack([random_asym_ids(n_res) for _ in range(batch_size)])
}
config = {
"weight": 0.05,
"clamp_distance": -4.0,
}
final_atom_positions = torch.rand(batch_size, n_res, 37, 3).cuda()
to_tensor = lambda t: torch.tensor(t).cuda()
batch = tree_map(to_tensor, batch, np.ndarray)
out_repro = chain_center_of_mass_loss(
all_atom_pred_pos=final_atom_positions,
**{**batch, **config},
)
out_repro = out_repro.cpu()
if __name__ == "__main__":
unittest.main()