Commit 51556d52 authored by Christina Floristean's avatar Christina Floristean
Browse files

Added multimer changes for loss functions

parent fbfbd808
...@@ -648,6 +648,7 @@ config = mlc.ConfigDict( ...@@ -648,6 +648,7 @@ config = mlc.ConfigDict(
"violation": { "violation": {
"violation_tolerance_factor": 12.0, "violation_tolerance_factor": 12.0,
"clash_overlap_tolerance": 1.5, "clash_overlap_tolerance": 1.5,
"average_clashes": False,
"eps": eps, # 1e-6, "eps": eps, # 1e-6,
"weight": 0.0, "weight": 0.0,
}, },
...@@ -660,6 +661,12 @@ config = mlc.ConfigDict( ...@@ -660,6 +661,12 @@ config = mlc.ConfigDict(
"weight": 0., "weight": 0.,
"enabled": tm_enabled, "enabled": tm_enabled,
}, },
"chain_center_of_mass": {
"clamp_distance": -4.0,
"weight": 0.,
"eps": eps,
"enabled": False,
},
"eps": eps, "eps": eps,
}, },
"ema": {"decay": 0.999}, "ema": {"decay": 0.999},
...@@ -802,7 +809,9 @@ multimer_model_config_update = { ...@@ -802,7 +809,9 @@ multimer_model_config_update = {
"tm": { "tm": {
"c_z": c_z, "c_z": c_z,
"no_bins": aux_distogram_bins, "no_bins": aux_distogram_bins,
"enabled": tm_enabled, "ptm_weight": 0.2,
"iptm_weight": 0.8,
"enabled": True,
}, },
"masked_msa": { "masked_msa": {
"c_m": c_m, "c_m": c_m,
...@@ -813,5 +822,81 @@ multimer_model_config_update = { ...@@ -813,5 +822,81 @@ multimer_model_config_update = {
"c_out": 37, "c_out": 37,
}, },
}, },
"loss": {
"distogram": {
"min_bin": 2.3125,
"max_bin": 21.6875,
"no_bins": 64,
"eps": eps, # 1e-6,
"weight": 0.3,
},
"experimentally_resolved": {
"eps": eps, # 1e-8,
"min_resolution": 0.1,
"max_resolution": 3.0,
"weight": 0.0,
},
"fape": {
"intra_chain_backbone": {
"clamp_distance": 10.0,
"loss_unit_distance": 10.0,
"weight": 0.5,
},
"interface": {
"clamp_distance": 30.0,
"loss_unit_distance": 20.0,
"weight": 0.5,
},
"sidechain": {
"clamp_distance": 10.0,
"length_scale": 10.0,
"weight": 0.5,
},
"eps": 1e-4,
"weight": 1.0,
},
"plddt_loss": {
"min_resolution": 0.1,
"max_resolution": 3.0,
"cutoff": 15.0,
"no_bins": 50,
"eps": eps, # 1e-10,
"weight": 0.01,
},
"masked_msa": {
"num_classes": 23,
"eps": eps, # 1e-8,
"weight": 2.0,
},
"supervised_chi": {
"chi_weight": 0.5,
"angle_norm_weight": 0.01,
"eps": eps, # 1e-6,
"weight": 1.0,
},
"violation": {
"violation_tolerance_factor": 12.0,
"clash_overlap_tolerance": 1.5,
"average_clashes": True,
"eps": eps, # 1e-6,
"weight": 0.03, # Not finetuning
},
"tm": {
"max_bin": 31,
"no_bins": 64,
"min_resolution": 0.1,
"max_resolution": 3.0,
"eps": eps, # 1e-8,
"weight": 0.1,
"enabled": True,
},
"chain_center_of_mass": {
"clamp_distance": -4.0,
"weight": 0.05,
"eps": eps,
"enabled": True,
},
"eps": eps,
},
"recycle_early_stop_tolerance": 0.5 "recycle_early_stop_tolerance": 0.5
} }
...@@ -76,9 +76,17 @@ class AuxiliaryHeads(nn.Module): ...@@ -76,9 +76,17 @@ class AuxiliaryHeads(nn.Module):
if self.config.tm.enabled: if self.config.tm.enabled:
tm_logits = self.tm(outputs["pair"]) tm_logits = self.tm(outputs["pair"])
aux_out["tm_logits"] = tm_logits aux_out["tm_logits"] = tm_logits
aux_out["predicted_tm_score"] = compute_tm( aux_out["ptm_score"] = compute_tm(
tm_logits, **self.config.tm tm_logits, **self.config.tm
) )
asym_id = outputs.get("asym_id")
if asym_id is not None:
aux_out["iptm_score"] = compute_tm(
tm_logits, asym_id=asym_id, interface=True, **self.config.tm
)
aux_out["weighted_ptm_score"] = (self.config.tm["iptm_weight"] * aux_out["iptm_score"]
+ self.config.tm["ptm_weight"] * aux_out["ptm_score"])
aux_out.update( aux_out.update(
compute_predicted_aligned_error( compute_predicted_aligned_error(
tm_logits, tm_logits,
......
...@@ -555,6 +555,9 @@ class AlphaFold(nn.Module): ...@@ -555,6 +555,9 @@ class AlphaFold(nn.Module):
else: else:
break break
if "asym_id" in batch:
outputs["asym_id"] = feats["asym_id"]
# Run auxiliary heads # Run auxiliary heads
outputs.update(self.aux_heads(outputs)) outputs.update(self.aux_heads(outputs))
......
...@@ -435,7 +435,7 @@ class TriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate): ...@@ -435,7 +435,7 @@ class TriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
# reduced-precision modes # reduced-precision modes
a_std = a.std() a_std = a.std()
b_std = b.std() b_std = b.std()
if(a_std != 0. and b_std != 0.): if(is_fp16_enabled() and a_std != 0. and b_std != 0.):
a = a / a.std() a = a / a.std()
b = b / b.std() b = b / b.std()
...@@ -589,6 +589,9 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate): ...@@ -589,6 +589,9 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
# Prevents overflow of torch.matmul in combine projections in # Prevents overflow of torch.matmul in combine projections in
# reduced-precision modes # reduced-precision modes
a_std = a.std()
b_std = b.std()
if (is_fp16_enabled() and a_std != 0. and b_std != 0.):
a = a / a.std() a = a / a.std()
b = b / b.std() b = b / b.std()
......
...@@ -193,7 +193,7 @@ def square_euclidean_distance( ...@@ -193,7 +193,7 @@ def square_euclidean_distance(
difference = vec1 - vec2 difference = vec1 - vec2
distance = difference.dot(difference) distance = difference.dot(difference)
if epsilon: if epsilon:
distance = torch.maximum(distance, epsilon) distance = torch.clamp(distance, min=epsilon)
return distance return distance
......
...@@ -617,7 +617,7 @@ def generate_translation_dict(model, version, is_multimer=False): ...@@ -617,7 +617,7 @@ def generate_translation_dict(model, version, is_multimer=False):
translations["evoformer"].update(template_param_dict) translations["evoformer"].update(template_param_dict)
if "_ptm" in version: if is_multimer or "_ptm" in version:
translations["predicted_aligned_error_head"] = { translations["predicted_aligned_error_head"] = {
"logits": LinearParams(model.aux_heads.tm.linear) "logits": LinearParams(model.aux_heads.tm.linear)
} }
......
...@@ -25,6 +25,8 @@ from typing import Dict, Optional, Tuple ...@@ -25,6 +25,8 @@ from typing import Dict, Optional, Tuple
from openfold.np import residue_constants from openfold.np import residue_constants
from openfold.utils import feats from openfold.utils import feats
from openfold.utils.rigid_utils import Rotation, Rigid from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.geometry.vector import Vec3Array, euclidean_distance
from openfold.utils.all_atom_multimer import get_rc_tensor
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
tree_map, tree_map,
tensor_tree_map, tensor_tree_map,
...@@ -87,6 +89,7 @@ def compute_fape( ...@@ -87,6 +89,7 @@ def compute_fape(
target_positions: torch.Tensor, target_positions: torch.Tensor,
positions_mask: torch.Tensor, positions_mask: torch.Tensor,
length_scale: float, length_scale: float,
pair_mask: Optional[torch.Tensor] = None,
l1_clamp_distance: Optional[float] = None, l1_clamp_distance: Optional[float] = None,
eps=1e-8, eps=1e-8,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -108,6 +111,9 @@ def compute_fape( ...@@ -108,6 +111,9 @@ def compute_fape(
[*, N_pts] positions mask [*, N_pts] positions mask
length_scale: length_scale:
Length scale by which the loss is divided Length scale by which the loss is divided
pair_mask:
[*, N_frames, N_pts] mask to use for
separating intra- from inter-chain losses.
l1_clamp_distance: l1_clamp_distance:
Cutoff above which distance errors are disregarded Cutoff above which distance errors are disregarded
eps: eps:
...@@ -134,6 +140,15 @@ def compute_fape( ...@@ -134,6 +140,15 @@ def compute_fape(
normed_error = normed_error * frames_mask[..., None] normed_error = normed_error * frames_mask[..., None]
normed_error = normed_error * positions_mask[..., None, :] normed_error = normed_error * positions_mask[..., None, :]
if pair_mask is not None:
normed_error = normed_error * pair_mask
normed_error = torch.sum(normed_error, dim=(-1, -2))
mask = frames_mask[..., None] * positions_mask[..., None, :] * pair_mask
norm_factor = torch.sum(mask, dim=(-2, -1))
normed_error = normed_error / (eps + norm_factor)
else:
# FP16-friendly averaging. Roughly equivalent to: # FP16-friendly averaging. Roughly equivalent to:
# #
# norm_factor = ( # norm_factor = (
...@@ -157,6 +172,7 @@ def backbone_loss( ...@@ -157,6 +172,7 @@ def backbone_loss(
backbone_rigid_tensor: torch.Tensor, backbone_rigid_tensor: torch.Tensor,
backbone_rigid_mask: torch.Tensor, backbone_rigid_mask: torch.Tensor,
traj: torch.Tensor, traj: torch.Tensor,
pair_mask: Optional[torch.Tensor] = None,
use_clamped_fape: Optional[torch.Tensor] = None, use_clamped_fape: Optional[torch.Tensor] = None,
clamp_distance: float = 10.0, clamp_distance: float = 10.0,
loss_unit_distance: float = 10.0, loss_unit_distance: float = 10.0,
...@@ -184,6 +200,7 @@ def backbone_loss( ...@@ -184,6 +200,7 @@ def backbone_loss(
pred_aff.get_trans(), pred_aff.get_trans(),
gt_aff[None].get_trans(), gt_aff[None].get_trans(),
backbone_rigid_mask[None], backbone_rigid_mask[None],
pair_mask=pair_mask,
l1_clamp_distance=clamp_distance, l1_clamp_distance=clamp_distance,
length_scale=loss_unit_distance, length_scale=loss_unit_distance,
eps=eps, eps=eps,
...@@ -196,6 +213,7 @@ def backbone_loss( ...@@ -196,6 +213,7 @@ def backbone_loss(
pred_aff.get_trans(), pred_aff.get_trans(),
gt_aff[None].get_trans(), gt_aff[None].get_trans(),
backbone_rigid_mask[None], backbone_rigid_mask[None],
pair_mask=pair_mask,
l1_clamp_distance=None, l1_clamp_distance=None,
length_scale=loss_unit_distance, length_scale=loss_unit_distance,
eps=eps, eps=eps,
...@@ -253,6 +271,7 @@ def sidechain_loss( ...@@ -253,6 +271,7 @@ def sidechain_loss(
sidechain_atom_pos, sidechain_atom_pos,
renamed_atom14_gt_positions, renamed_atom14_gt_positions,
renamed_atom14_gt_exists, renamed_atom14_gt_exists,
pair_mask=None,
l1_clamp_distance=clamp_distance, l1_clamp_distance=clamp_distance,
length_scale=length_scale, length_scale=length_scale,
eps=eps, eps=eps,
...@@ -266,10 +285,29 @@ def fape_loss( ...@@ -266,10 +285,29 @@ def fape_loss(
batch: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor],
config: ml_collections.ConfigDict, config: ml_collections.ConfigDict,
) -> torch.Tensor: ) -> torch.Tensor:
traj = out["sm"]["frames"]
asym_id = batch.get("asym_id")
if asym_id is not None:
intra_chain_mask = (asym_id[..., None] == asym_id[..., None, :]).to(dtype=traj.dtype)
intra_chain_bb_loss = backbone_loss(
traj=traj,
pair_mask=intra_chain_mask,
**{**batch, **config.intra_chain_backbone},
)
interface_bb_loss = backbone_loss(
traj=traj,
pair_mask=1. - intra_chain_mask,
**{**batch, **config.interface_backbone},
)
weighted_bb_loss = (intra_chain_bb_loss * config.intra_chain_backbone.weight
+ interface_bb_loss * config.interface_backbone.weight)
else:
bb_loss = backbone_loss( bb_loss = backbone_loss(
traj=out["sm"]["frames"], traj=traj,
**{**batch, **config.backbone}, **{**batch, **config.backbone},
) )
weighted_bb_loss = bb_loss * config.backbone.weight
sc_loss = sidechain_loss( sc_loss = sidechain_loss(
out["sm"]["sidechain_frames"], out["sm"]["sidechain_frames"],
...@@ -277,7 +315,7 @@ def fape_loss( ...@@ -277,7 +315,7 @@ def fape_loss(
**{**batch, **config.sidechain}, **{**batch, **config.sidechain},
) )
loss = config.backbone.weight * bb_loss + config.sidechain.weight * sc_loss loss = weighted_bb_loss + config.sidechain.weight * sc_loss
# Average over the batch dimension # Average over the batch dimension
loss = torch.mean(loss) loss = torch.mean(loss)
...@@ -654,7 +692,7 @@ def compute_tm( ...@@ -654,7 +692,7 @@ def compute_tm(
n = residue_weights.shape[-1] n = residue_weights.shape[-1]
pair_mask = residue_weights.new_ones((n, n), dtype=torch.int32) pair_mask = residue_weights.new_ones((n, n), dtype=torch.int32)
if interface: if interface:
pair_mask *= (asym_id[..., None] != asym_id[..., None, :]) pair_mask *= (asym_id[..., None] != asym_id[..., None, :]).to(dtype=pair_mask.dtype)
predicted_tm_term *= pair_mask predicted_tm_term *= pair_mask
...@@ -891,6 +929,7 @@ def between_residue_clash_loss( ...@@ -891,6 +929,7 @@ def between_residue_clash_loss(
atom14_atom_exists: torch.Tensor, atom14_atom_exists: torch.Tensor,
atom14_atom_radius: torch.Tensor, atom14_atom_radius: torch.Tensor,
residue_index: torch.Tensor, residue_index: torch.Tensor,
asym_id: Optional[torch.Tensor] = None,
overlap_tolerance_soft=1.5, overlap_tolerance_soft=1.5,
overlap_tolerance_hard=1.5, overlap_tolerance_hard=1.5,
eps=1e-10, eps=1e-10,
...@@ -966,9 +1005,13 @@ def between_residue_clash_loss( ...@@ -966,9 +1005,13 @@ def between_residue_clash_loss(
) )
n_one_hot = n_one_hot.type(fp_type) n_one_hot = n_one_hot.type(fp_type)
neighbour_mask = ( neighbour_mask = (residue_index[..., :, None] + 1) == residue_index[..., None, :]
residue_index[..., :, None, None, None] + 1
) == residue_index[..., None, :, None, None] if asym_id is not None:
neighbour_mask = neighbour_mask & (asym_id[..., :, None] == asym_id[..., None, :])
neighbour_mask = neighbour_mask[..., None, None]
c_n_bonds = ( c_n_bonds = (
neighbour_mask neighbour_mask
* c_one_hot[..., None, None, :, None] * c_one_hot[..., None, None, :, None]
...@@ -1010,7 +1053,7 @@ def between_residue_clash_loss( ...@@ -1010,7 +1053,7 @@ def between_residue_clash_loss(
# Compute the per atom loss sum. # Compute the per atom loss sum.
# shape (N, 14) # shape (N, 14)
per_atom_loss_sum = torch.sum(dists_to_low_error, dim=(-4, -2)) + torch.sum( per_atom_loss_sum = torch.sum(dists_to_low_error, dim=(-4, -2)) + torch.sum(
dists_to_low_error, axis=(-3, -1) dists_to_low_error, dim=(-3, -1)
) )
# Compute the hard clash mask. # Compute the hard clash mask.
...@@ -1019,17 +1062,20 @@ def between_residue_clash_loss( ...@@ -1019,17 +1062,20 @@ def between_residue_clash_loss(
dists < (dists_lower_bound - overlap_tolerance_hard) dists < (dists_lower_bound - overlap_tolerance_hard)
) )
per_atom_num_clash = torch.sum(clash_mask, dim=(-4, -2)) + torch.sum(clash_mask, dim=(-3, -1))
# Compute the per atom clash. # Compute the per atom clash.
# shape (N, 14) # shape (N, 14)
per_atom_clash_mask = torch.maximum( per_atom_clash_mask = torch.maximum(
torch.amax(clash_mask, axis=(-4, -2)), torch.amax(clash_mask, dim=(-4, -2)),
torch.amax(clash_mask, axis=(-3, -1)), torch.amax(clash_mask, dim=(-3, -1)),
) )
return { return {
"mean_loss": mean_loss, # shape () "mean_loss": mean_loss, # shape ()
"per_atom_loss_sum": per_atom_loss_sum, # shape (N, 14) "per_atom_loss_sum": per_atom_loss_sum, # shape (N, 14)
"per_atom_clash_mask": per_atom_clash_mask, # shape (N, 14) "per_atom_clash_mask": per_atom_clash_mask, # shape (N, 14)
"per_atom_num_clash": per_atom_num_clash # shape (N, 14)
} }
...@@ -1109,6 +1155,8 @@ def within_residue_violations( ...@@ -1109,6 +1155,8 @@ def within_residue_violations(
(dists < atom14_dists_lower_bound) | (dists > atom14_dists_upper_bound) (dists < atom14_dists_lower_bound) | (dists > atom14_dists_upper_bound)
) )
per_atom_num_clash = torch.sum(violations, dim=-2) + torch.sum(violations, dim=-1)
# Compute the per atom violations. # Compute the per atom violations.
per_atom_violations = torch.maximum( per_atom_violations = torch.maximum(
torch.max(violations, dim=-2)[0], torch.max(violations, axis=-1)[0] torch.max(violations, dim=-2)[0], torch.max(violations, axis=-1)[0]
...@@ -1117,6 +1165,7 @@ def within_residue_violations( ...@@ -1117,6 +1165,7 @@ def within_residue_violations(
return { return {
"per_atom_loss_sum": per_atom_loss_sum, "per_atom_loss_sum": per_atom_loss_sum,
"per_atom_violations": per_atom_violations, "per_atom_violations": per_atom_violations,
"per_atom_num_clash": per_atom_num_clash
} }
...@@ -1146,7 +1195,20 @@ def find_structural_violations( ...@@ -1146,7 +1195,20 @@ def find_structural_violations(
residue_constants.van_der_waals_radius[name[0]] residue_constants.van_der_waals_radius[name[0]]
for name in residue_constants.atom_types for name in residue_constants.atom_types
] ]
atomtype_radius = atom14_pred_positions.new_tensor(atomtype_radius) atomtype_radius = atom14_pred_positions.new_tensor(atomtype_radius)
#TODO: Consolidate monomer/multimer modes
asym_id = batch.get("asym_id")
if asym_id is not None:
residx_atom14_to_atom37 = get_rc_tensor(
residue_constants.RESTYPE_ATOM14_TO_ATOM37, batch["aatype"]
)
atom14_atom_radius = (
batch["atom14_atom_exists"]
* atomtype_radius[residx_atom14_to_atom37.long()]
)
else:
atom14_atom_radius = ( atom14_atom_radius = (
batch["atom14_atom_exists"] batch["atom14_atom_exists"]
* atomtype_radius[batch["residx_atom14_to_atom37"]] * atomtype_radius[batch["residx_atom14_to_atom37"]]
...@@ -1158,6 +1220,7 @@ def find_structural_violations( ...@@ -1158,6 +1220,7 @@ def find_structural_violations(
atom14_atom_exists=batch["atom14_atom_exists"], atom14_atom_exists=batch["atom14_atom_exists"],
atom14_atom_radius=atom14_atom_radius, atom14_atom_radius=atom14_atom_radius,
residue_index=batch["residue_index"], residue_index=batch["residue_index"],
asym_id=asym_id,
overlap_tolerance_soft=clash_overlap_tolerance, overlap_tolerance_soft=clash_overlap_tolerance,
overlap_tolerance_hard=clash_overlap_tolerance, overlap_tolerance_hard=clash_overlap_tolerance,
) )
...@@ -1220,6 +1283,9 @@ def find_structural_violations( ...@@ -1220,6 +1283,9 @@ def find_structural_violations(
"clashes_per_atom_clash_mask": between_residue_clashes[ "clashes_per_atom_clash_mask": between_residue_clashes[
"per_atom_clash_mask" "per_atom_clash_mask"
], # (N, 14) ], # (N, 14)
"clashes_per_atom_num_clash": between_residue_clashes[
"per_atom_num_clash"
], # (N, 14)
}, },
"within_residues": { "within_residues": {
"per_atom_loss_sum": residue_violations[ "per_atom_loss_sum": residue_violations[
...@@ -1228,6 +1294,9 @@ def find_structural_violations( ...@@ -1228,6 +1294,9 @@ def find_structural_violations(
"per_atom_violations": residue_violations[ "per_atom_violations": residue_violations[
"per_atom_violations" "per_atom_violations"
], # (N, 14), ], # (N, 14),
"per_atom_num_clash": residue_violations[
"per_atom_num_clash"
], # (N, 14)
}, },
"total_per_residue_violations_mask": per_residue_violations_mask, # (N) "total_per_residue_violations_mask": per_residue_violations_mask, # (N)
} }
...@@ -1349,15 +1418,21 @@ def compute_violation_metrics_np( ...@@ -1349,15 +1418,21 @@ def compute_violation_metrics_np(
def violation_loss( def violation_loss(
violations: Dict[str, torch.Tensor], violations: Dict[str, torch.Tensor],
atom14_atom_exists: torch.Tensor, atom14_atom_exists: torch.Tensor,
average_clashes: bool = False,
eps=1e-6, eps=1e-6,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
num_atoms = torch.sum(atom14_atom_exists) num_atoms = torch.sum(atom14_atom_exists)
l_clash = torch.sum(
violations["between_residues"]["clashes_per_atom_loss_sum"] per_atom_clash = (violations["between_residues"]["clashes_per_atom_loss_sum"] +
+ violations["within_residues"]["per_atom_loss_sum"] violations["within_residues"]["per_atom_loss_sum"])
)
l_clash = l_clash / (eps + num_atoms) if average_clashes:
num_clash = (violations["between_residues"]["clashes_per_atom_num_clash"] +
violations["within_residues"]["per_atom_num_clash"])
per_atom_clash = per_atom_clash / (num_clash + eps)
l_clash = torch.sum(per_atom_clash) / (eps + num_atoms)
loss = ( loss = (
violations["between_residues"]["bonds_c_n_loss_mean"] violations["between_residues"]["bonds_c_n_loss_mean"]
+ violations["between_residues"]["angles_ca_c_n_loss_mean"] + violations["between_residues"]["angles_ca_c_n_loss_mean"]
...@@ -1533,6 +1608,64 @@ def masked_msa_loss(logits, true_msa, bert_mask, num_classes, eps=1e-8, **kwargs ...@@ -1533,6 +1608,64 @@ def masked_msa_loss(logits, true_msa, bert_mask, num_classes, eps=1e-8, **kwargs
return loss return loss
def chain_center_of_mass_loss(
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    asym_id: torch.Tensor,
    clamp_distance: float = -4.0,
    weight: float = 0.05,
    eps: float = 1e-10,
    **kwargs,
) -> torch.Tensor:
    """
    Computes chain centre-of-mass loss. Implements section 2.5, eqn 1 in the
    Multimer paper.

    Args:
        all_atom_pred_pos:
            [*, N_pts, 37, 3] All-atom predicted atom positions
        all_atom_positions:
            [*, N_pts, 37, 3] Ground truth all-atom positions
        all_atom_mask:
            [*, N_pts, 37] All-atom positions mask
        asym_id:
            [*, N_pts] Chain asym IDs
        clamp_distance:
            Cutoff above which distance errors are disregarded
        weight:
            Weight for loss
        eps:
            Small value used to regularize denominators
        kwargs:
            Ignored; allows callers to splat a full feature/config dict
            (e.g. ``**{**batch, **config.chain_center_of_mass}``).
    Returns:
        [*] loss tensor
    """
    # Only CA atoms contribute to each chain's centre of mass.
    ca_pos = residue_constants.atom_order["CA"]
    all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
    all_atom_positions = all_atom_positions[..., ca_pos, :]
    all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)]  # keep dim

    # Map (possibly non-contiguous or non-zero-based) chain IDs onto a dense
    # 0..N_chains-1 index so the one-hot encoding below is always valid.
    chains, dense_asym_id = asym_id.unique(return_inverse=True)
    one_hot = torch.nn.functional.one_hot(
        dense_asym_id.long(), num_classes=chains.shape[0]
    ).to(dtype=all_atom_mask.dtype)
    one_hot = one_hot * all_atom_mask
    chain_pos_mask = one_hot.transpose(-2, -1)
    # A chain "exists" if at least one of its CA atoms is resolved.
    chain_exists = torch.any(chain_pos_mask, dim=-1).float()

    def get_chain_center_of_mass(pos):
        # Masked mean of CA positions per chain: [*, N_chains, 3]
        center_sum = (chain_pos_mask[..., None] * pos[..., None, :, :]).sum(dim=-2)
        centers = center_sum / (torch.sum(chain_pos_mask, dim=-1, keepdim=True) + eps)
        return Vec3Array.from_array(centers)

    pred_centers = get_chain_center_of_mass(all_atom_pred_pos)  # [B, NC, 3]
    true_centers = get_chain_center_of_mass(all_atom_positions)  # [B, NC, 3]

    pred_dists = euclidean_distance(pred_centers[..., None, :], pred_centers[..., :, None], epsilon=eps)
    true_dists = euclidean_distance(true_centers[..., None, :], true_centers[..., :, None], epsilon=eps)
    # Clamped, weighted squared error between pairwise chain-centre distances.
    losses = torch.clamp((weight * (pred_dists - true_dists - clamp_distance)), max=0) ** 2
    # Only score chain pairs where both chains have resolved atoms.
    loss_mask = chain_exists[..., :, None] * chain_exists[..., None, :]

    loss = masked_mean(loss_mask, losses, dim=(-1, -2))
    return loss
class AlphaFoldLoss(nn.Module): class AlphaFoldLoss(nn.Module):
"""Aggregation of the various losses described in the supplement""" """Aggregation of the various losses described in the supplement"""
def __init__(self, config): def __init__(self, config):
...@@ -1585,7 +1718,7 @@ class AlphaFoldLoss(nn.Module): ...@@ -1585,7 +1718,7 @@ class AlphaFoldLoss(nn.Module):
), ),
"violation": lambda: violation_loss( "violation": lambda: violation_loss(
out["violation"], out["violation"],
**batch, **{**batch, **self.config.violation},
), ),
} }
...@@ -1595,6 +1728,12 @@ class AlphaFoldLoss(nn.Module): ...@@ -1595,6 +1728,12 @@ class AlphaFoldLoss(nn.Module):
**{**batch, **out, **self.config.tm}, **{**batch, **out, **self.config.tm},
) )
if (self.config.chain_center_of_mass.enabled):
loss_fns["chain_center_of_mass"] = lambda: chain_center_of_mass_loss(
all_atom_pred_pos=out["final_atom_positions"],
**{**batch, **self.config.chain_center_of_mass},
)
cum_loss = 0. cum_loss = 0.
losses = {} losses = {}
for loss_name, loss_fn in loss_fns.items(): for loss_name, loss_fn in loss_fns.items():
......
...@@ -6,7 +6,7 @@ consts = mlc.ConfigDict( ...@@ -6,7 +6,7 @@ consts = mlc.ConfigDict(
"is_multimer": True, # monomer: False, multimer: True "is_multimer": True, # monomer: False, multimer: True
"chunk_size": 4, "chunk_size": 4,
"batch_size": 2, "batch_size": 2,
"n_res": 11, "n_res": 22,
"n_seq": 13, "n_seq": 13,
"n_templ": 3, "n_templ": 3,
"n_extra": 17, "n_extra": 17,
......
...@@ -29,14 +29,16 @@ def random_asym_ids(n_res, split_chains=True, min_chain_len=4): ...@@ -29,14 +29,16 @@ def random_asym_ids(n_res, split_chains=True, min_chain_len=4):
pieces = [] pieces = []
asym_ids = [] asym_ids = []
final_idx = n_chain - 1
for idx in range(n_chain - 1): for idx in range(n_chain - 1):
n_stop = (n_res - sum(pieces) - n_chain + idx - min_chain_len) n_stop = (n_res - sum(pieces) - n_chain + idx - min_chain_len)
if n_stop <= min_chain_len: if n_stop <= min_chain_len:
final_idx = idx
break break
piece = randint(min_chain_len, n_stop) piece = randint(min_chain_len, n_stop)
pieces.append(piece) pieces.append(piece)
asym_ids.extend(piece * [idx]) asym_ids.extend(piece * [idx])
asym_ids.extend((n_res - sum(pieces)) * [n_chain - 1]) asym_ids.extend((n_res - sum(pieces)) * [final_idx])
return np.array(asym_ids).astype(np.int64) return np.array(asym_ids).astype(np.int64)
......
...@@ -20,6 +20,7 @@ import unittest ...@@ -20,6 +20,7 @@ import unittest
import ml_collections as mlc import ml_collections as mlc
from openfold.data import data_transforms from openfold.data import data_transforms
from openfold.np import residue_constants
from openfold.utils.rigid_utils import ( from openfold.utils.rigid_utils import (
Rotation, Rotation,
Rigid, Rigid,
...@@ -42,6 +43,8 @@ from openfold.utils.loss import ( ...@@ -42,6 +43,8 @@ from openfold.utils.loss import (
sidechain_loss, sidechain_loss,
tm_loss, tm_loss,
compute_plddt, compute_plddt,
compute_tm,
chain_center_of_mass_loss
) )
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
tree_map, tree_map,
...@@ -233,10 +236,23 @@ class TestLoss(unittest.TestCase): ...@@ -233,10 +236,23 @@ class TestLoss(unittest.TestCase):
pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32) pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
atom_exists = np.random.randint(0, 2, (n_res, 14)).astype(np.float32) atom_exists = np.random.randint(0, 2, (n_res, 14)).astype(np.float32)
atom_radius = np.random.rand(n_res, 14).astype(np.float32)
res_ind = np.arange( res_ind = np.arange(
n_res, n_res,
) )
residx_atom14_to_atom37 = np.random.randint(0, 37, (n_res, 14)).astype(np.int64)
atomtype_radius = [
residue_constants.van_der_waals_radius[name[0]]
for name in residue_constants.atom_types
]
atomtype_radius = np.array(atomtype_radius).astype(np.float32)
atom_radius = (
atom_exists
* atomtype_radius[residx_atom14_to_atom37]
)
asym_id = None
if consts.is_multimer:
asym_id = random_asym_ids(n_res) asym_id = random_asym_ids(n_res)
out_gt = f.apply( out_gt = f.apply(
...@@ -256,6 +272,7 @@ class TestLoss(unittest.TestCase): ...@@ -256,6 +272,7 @@ class TestLoss(unittest.TestCase):
torch.tensor(atom_exists).cuda(), torch.tensor(atom_exists).cuda(),
torch.tensor(atom_radius).cuda(), torch.tensor(atom_radius).cuda(),
torch.tensor(res_ind).cuda(), torch.tensor(res_ind).cuda(),
torch.tensor(asym_id).cuda() if asym_id is not None else None,
) )
out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro) out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)
...@@ -279,6 +296,36 @@ class TestLoss(unittest.TestCase): ...@@ -279,6 +296,36 @@ class TestLoss(unittest.TestCase):
torch.max(torch.abs(out_gt - out_repro)) < consts.eps torch.max(torch.abs(out_gt - out_repro)) < consts.eps
) )
@compare_utils.skip_unless_alphafold_installed()
def test_compute_ptm_compare(self):
    """Check that compute_tm reproduces AlphaFold's pTM (and ipTM for multimer)."""
    num_res = consts.n_res
    max_bin = 31
    no_bins = 64

    logits_np = np.random.rand(num_res, num_res, no_bins)
    bin_edges = np.linspace(0, max_bin, num=(no_bins - 1))

    # Ground truth from the reference AlphaFold implementation
    expected_ptm = torch.tensor(
        alphafold.common.confidence.predicted_tm_score(logits_np, bin_edges)
    )

    logits = torch.tensor(logits_np)
    actual_ptm = compute_tm(logits, no_bins=no_bins, max_bin=max_bin)

    self.assertTrue(torch.max(torch.abs(expected_ptm - actual_ptm)) < consts.eps)

    if not consts.is_multimer:
        return

    # Interface pTM additionally needs per-residue chain IDs
    asym_id = random_asym_ids(num_res)
    expected_iptm = torch.tensor(
        alphafold.common.confidence.predicted_tm_score(
            logits_np, bin_edges, asym_id=asym_id, interface=True
        )
    )
    actual_iptm = compute_tm(
        logits,
        no_bins=no_bins,
        max_bin=max_bin,
        asym_id=torch.tensor(asym_id),
        interface=True,
    )
    self.assertTrue(torch.max(torch.abs(expected_iptm - actual_iptm)) < consts.eps)
def test_find_structural_violations(self): def test_find_structural_violations(self):
n = consts.n_res n = consts.n_res
...@@ -335,9 +382,11 @@ class TestLoss(unittest.TestCase): ...@@ -335,9 +382,11 @@ class TestLoss(unittest.TestCase):
"residx_atom14_to_atom37": np.random.randint( "residx_atom14_to_atom37": np.random.randint(
0, 37, (n_res, 14) 0, 37, (n_res, 14)
).astype(np.int64), ).astype(np.int64),
"asym_id": random_asym_ids(n_res)
} }
if consts.is_multimer:
batch["asym_id"] = random_asym_ids(n_res)
pred_pos = np.random.rand(n_res, 14, 3) pred_pos = np.random.rand(n_res, 14, 3)
config = mlc.ConfigDict( config = mlc.ConfigDict(
...@@ -632,6 +681,40 @@ class TestLoss(unittest.TestCase): ...@@ -632,6 +681,40 @@ class TestLoss(unittest.TestCase):
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps) self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed()
def test_violation_loss(self):
    """Smoke-test violation_loss in both clash-summing and clash-averaging modes.

    The original version computed both losses but asserted nothing; we now
    verify each result is a finite, non-negative value.
    """
    config = compare_utils.get_alphafold_config()
    c_viol = config.model.heads.structure_module
    n_res = consts.n_res

    batch = {
        "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
        "residue_index": np.arange(n_res),
        "aatype": np.random.randint(0, 21, (n_res,)),
    }
    if consts.is_multimer:
        batch["asym_id"] = random_asym_ids(n_res)

    batch = tree_map(lambda n: torch.tensor(n).cuda(), batch, np.ndarray)

    atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
    atom14_pred_pos = torch.tensor(atom14_pred_pos).cuda()

    batch = data_transforms.make_atom14_masks(batch)

    loss_sum_clash = violation_loss(
        find_structural_violations(batch, atom14_pred_pos, **c_viol),
        average_clashes=False, **batch
    )
    loss_sum_clash = loss_sum_clash.cpu()

    loss_avg_clash = violation_loss(
        find_structural_violations(batch, atom14_pred_pos, **c_viol),
        average_clashes=True, **batch
    )
    loss_avg_clash = loss_avg_clash.cpu()

    # All violation terms are clamped penalties, so the loss must be a
    # finite, non-negative value in both modes.
    self.assertTrue(torch.isfinite(loss_sum_clash))
    self.assertTrue(torch.isfinite(loss_avg_clash))
    self.assertTrue(loss_sum_clash >= 0.)
    self.assertTrue(loss_avg_clash >= 0.)
@compare_utils.skip_unless_alphafold_installed() @compare_utils.skip_unless_alphafold_installed()
def test_violation_loss_compare(self): def test_violation_loss_compare(self):
config = compare_utils.get_alphafold_config() config = compare_utils.get_alphafold_config()
...@@ -680,10 +763,12 @@ class TestLoss(unittest.TestCase): ...@@ -680,10 +763,12 @@ class TestLoss(unittest.TestCase):
batch = { batch = {
"seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32), "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
"residue_index": np.arange(n_res), "residue_index": np.arange(n_res),
"aatype": np.random.randint(0, 21, (n_res,)), "aatype": np.random.randint(0, 21, (n_res,))
"asym_id": random_asym_ids(n_res)
} }
if consts.is_multimer:
batch["asym_id"] = random_asym_ids(n_res)
atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32) atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
alphafold.model.tf.data_transforms.make_atom14_masks(batch) alphafold.model.tf.data_transforms.make_atom14_masks(batch)
...@@ -801,8 +886,7 @@ class TestLoss(unittest.TestCase): ...@@ -801,8 +886,7 @@ class TestLoss(unittest.TestCase):
"backbone_affine_mask": np.random.randint(0, 2, (n_res,)).astype( "backbone_affine_mask": np.random.randint(0, 2, (n_res,)).astype(
np.float32 np.float32
), ),
"use_clamped_fape": np.array(0.0), "use_clamped_fape": np.array(0.0)
"asym_id": random_asym_ids(n_res)
} }
value = { value = {
...@@ -814,6 +898,9 @@ class TestLoss(unittest.TestCase): ...@@ -814,6 +898,9 @@ class TestLoss(unittest.TestCase):
), ),
} }
if consts.is_multimer:
batch["asym_id"] = random_asym_ids(n_res)
out_gt = f.apply({}, None, batch, value) out_gt = f.apply({}, None, batch, value)
out_gt = torch.tensor(np.array(out_gt.block_until_ready())) out_gt = torch.tensor(np.array(out_gt.block_until_ready()))
...@@ -826,6 +913,16 @@ class TestLoss(unittest.TestCase): ...@@ -826,6 +913,16 @@ class TestLoss(unittest.TestCase):
) )
batch["backbone_rigid_mask"] = batch["backbone_affine_mask"] batch["backbone_rigid_mask"] = batch["backbone_affine_mask"]
if consts.is_multimer:
intra_chain_mask = (batch["asym_id"][..., None]
== batch["asym_id"][..., None, :]).to(dtype=value["traj"].dtype)
intra_chain_out = backbone_loss(traj=value["traj"], pair_mask=intra_chain_mask,
**{**batch, **c_sm.intra_chain_fape})
interface_out = backbone_loss(traj=value["traj"], pair_mask=1. - intra_chain_mask,
**{**batch, **c_sm.interface_fape})
out_repro = intra_chain_out + interface_out
out_repro = out_repro.cpu()
else:
out_repro = backbone_loss(traj=value["traj"], **{**batch, **c_sm}) out_repro = backbone_loss(traj=value["traj"], **{**batch, **c_sm})
out_repro = out_repro.cpu() out_repro = out_repro.cpu()
...@@ -869,14 +966,14 @@ class TestLoss(unittest.TestCase): ...@@ -869,14 +966,14 @@ class TestLoss(unittest.TestCase):
v["sidechains"] = {} v["sidechains"] = {}
v["sidechains"][ v["sidechains"][
"frames" "frames"
] = alphafold.model.r3.rigids_from_tensor4x4( ] = self.am_rigid.rigids_from_tensor4x4(
value["sidechains"]["frames"] value["sidechains"]["frames"]
) )
v["sidechains"]["atom_pos"] = alphafold.model.r3.vecs_from_tensor( v["sidechains"]["atom_pos"] = self.am_rigid.vecs_from_tensor(
value["sidechains"]["atom_pos"] value["sidechains"]["atom_pos"]
) )
v.update( v.update(
alphafold.model.folding.compute_renamed_ground_truth( self.am_fold.compute_renamed_ground_truth(
batch, batch,
atom14_pred_positions, atom14_pred_positions,
) )
...@@ -907,9 +1004,6 @@ class TestLoss(unittest.TestCase): ...@@ -907,9 +1004,6 @@ class TestLoss(unittest.TestCase):
), ),
} }
if consts.is_multimer:
batch["asym_id"] = random_asym_ids(n_res)
def _build_extra_feats_np(): def _build_extra_feats_np():
b = tree_map(lambda n: torch.tensor(n), batch, np.ndarray) b = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
b = data_transforms.make_atom14_masks(b) b = data_transforms.make_atom14_masks(b)
...@@ -950,7 +1044,7 @@ class TestLoss(unittest.TestCase): ...@@ -950,7 +1044,7 @@ class TestLoss(unittest.TestCase):
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps) self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed() @compare_utils.skip_unless_alphafold_installed()
@unittest.skipIf(consts.is_multimer or "ptm" not in consts.model, "Not enabled for non-ptm models.") @unittest.skipIf(not consts.is_multimer and "ptm" not in consts.model, "Not enabled for non-ptm models.")
def test_tm_loss_compare(self): def test_tm_loss_compare(self):
config = compare_utils.get_alphafold_config() config = compare_utils.get_alphafold_config()
c_tm = config.model.heads.predicted_aligned_error c_tm = config.model.heads.predicted_aligned_error
...@@ -1017,6 +1111,33 @@ class TestLoss(unittest.TestCase): ...@@ -1017,6 +1111,33 @@ class TestLoss(unittest.TestCase):
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps) self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed()
def test_chain_center_of_mass_loss(self):
    """Sanity-check chain_center_of_mass_loss on random multimer input.

    The original version computed the loss but asserted nothing; we now
    verify one finite, non-negative loss value per batch element.
    """
    batch_size = consts.batch_size
    n_res = consts.n_res
    batch = {
        "all_atom_positions": np.random.rand(batch_size, n_res, 37, 3).astype(np.float32) * 10.0,
        "all_atom_mask": np.random.randint(0, 2, (batch_size, n_res, 37)).astype(np.float32),
        "asym_id": np.stack([random_asym_ids(n_res) for _ in range(batch_size)])
    }
    config = {
        "weight": 0.05,
        "clamp_distance": -4.0,
    }

    final_atom_positions = torch.rand(batch_size, n_res, 37, 3).cuda()
    to_tensor = lambda t: torch.tensor(t).cuda()
    batch = tree_map(to_tensor, batch, np.ndarray)

    out_repro = chain_center_of_mass_loss(
        all_atom_pred_pos=final_atom_positions,
        **{**batch, **config},
    )
    out_repro = out_repro.cpu()

    # The loss is a masked mean of squared clamped terms: one value per
    # batch element, finite and >= 0.
    self.assertEqual(out_repro.shape, (batch_size,))
    self.assertTrue(torch.all(torch.isfinite(out_repro)))
    self.assertTrue(torch.all(out_repro >= 0.))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment