Commit 33941e46 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix loss bugs

parent 1bc68426
...@@ -211,7 +211,7 @@ config = mlc.ConfigDict({ ...@@ -211,7 +211,7 @@ config = mlc.ConfigDict({
"min_resolution": 0.1, "min_resolution": 0.1,
"max_resolution": 3.0, "max_resolution": 3.0,
"cutoff": 15., "cutoff": 15.,
"num_bins": 50, "no_bins": 50,
"eps": 1e-10, "eps": 1e-10,
"weight": 0.01, "weight": 0.01,
}, },
......
...@@ -49,7 +49,7 @@ class AuxiliaryHeads(nn.Module): ...@@ -49,7 +49,7 @@ class AuxiliaryHeads(nn.Module):
def forward(self, outputs): def forward(self, outputs):
aux_out = {} aux_out = {}
lddt_logits = self.plddt(outputs["single"]) lddt_logits = self.plddt(outputs["sm"]["single"])
aux_out["lddt_logits"] = lddt_logits aux_out["lddt_logits"] = lddt_logits
# Required for relaxation later on # Required for relaxation later on
......
...@@ -751,7 +751,7 @@ class StructureModule(nn.Module): ...@@ -751,7 +751,7 @@ class StructureModule(nn.Module):
t = t.stop_rot_gradient() t = t.stop_rot_gradient()
outputs = stack_tensor_dicts(outputs) outputs = stack_tensor_dicts(outputs)
outputs["single_act"] = s outputs["single"] = s
return outputs return outputs
......
...@@ -74,6 +74,7 @@ def get_chi_atom_indices(): ...@@ -74,6 +74,7 @@ def get_chi_atom_indices():
def compute_residx(batch): def compute_residx(batch):
float_type = batch["seq_mask"].dtype
aatype = batch["aatype"] aatype = batch["aatype"]
restype_atom14_to_atom37 = [] # mapping (restype, atom37) --> atom14 restype_atom14_to_atom37 = [] # mapping (restype, atom37) --> atom14
...@@ -104,34 +105,28 @@ def compute_residx(batch): ...@@ -104,34 +105,28 @@ def compute_residx(batch):
restype_atom37_to_atom14.append([0] * 37) restype_atom37_to_atom14.append([0] * 37)
restype_atom14_mask.append([0.] * 14) restype_atom14_mask.append([0.] * 14)
restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32) restype_atom14_to_atom37 = aatype.new_tensor(
restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32) restype_atom14_to_atom37
restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32) )
restype_atom37_to_atom14 = aatype.new_tensor(
residx_atom14_to_atom37 = np.take_along_axis( restype_atom37_to_atom14
restype_atom14_to_atom37,
aatype[..., None],
axis=0
) )
residx_atom14_mask = np.take_along_axis( restype_atom14_mask = aatype.new_tensor(
restype_atom14_mask, restype_atom14_mask, dtype=float_type
aatype[..., None],
axis=0,
) )
residx_atom14_to_atom37 = restype_atom14_to_atom37[aatype]
residx_atom14_mask = restype_atom14_mask[aatype]
batch['atom14_atom_exists'] = residx_atom14_mask batch['atom14_atom_exists'] = residx_atom14_mask
batch['residx_atom14_to_atom37'] = residx_atom14_to_atom37.long() batch['residx_atom14_to_atom37'] = residx_atom14_to_atom37
# create the gather indices for mapping back # create the gather indices for mapping back
residx_atom37_to_atom14 = np.take_along_axis( residx_atom37_to_atom14 = restype_atom37_to_atom14[aatype]
restype_atom37_to_atom14, batch['residx_atom37_to_atom14'] = residx_atom37_to_atom14
aatype[..., None],
axis=0,
)
batch['residx_atom37_to_atom14'] = residx_atom37_to_atom14.long()
# create the corresponding mask # create the corresponding mask
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32) restype_atom37_mask = torch.zeros([21, 37], dtype=float_type)
for restype, restype_letter in enumerate(residue_constants.restypes): for restype, restype_letter in enumerate(residue_constants.restypes):
restype_name = residue_constants.restype_1to3[restype_letter] restype_name = residue_constants.restype_1to3[restype_letter]
atom_names = residue_constants.residue_atoms[restype_name] atom_names = residue_constants.residue_atoms[restype_name]
...@@ -139,11 +134,7 @@ def compute_residx(batch): ...@@ -139,11 +134,7 @@ def compute_residx(batch):
atom_type = residue_constants.atom_order[atom_name] atom_type = residue_constants.atom_order[atom_name]
restype_atom37_mask[restype, atom_type] = 1 restype_atom37_mask[restype, atom_type] = 1
residx_atom37_mask = np.take_along_axis( residx_atom37_mask = restype_atom37_mask[aatype]
restype_atom37_mask,
aatype[..., None],
axis=0,
)
batch['atom37_atom_exists'] = residx_atom37_mask batch['atom37_atom_exists'] = residx_atom37_mask
......
...@@ -27,6 +27,7 @@ from openfold.utils.tensor_utils import ( ...@@ -27,6 +27,7 @@ from openfold.utils.tensor_utils import (
tree_map, tree_map,
tensor_tree_map, tensor_tree_map,
masked_mean, masked_mean,
permute_final_dims,
) )
...@@ -289,28 +290,25 @@ def lddt_loss( ...@@ -289,28 +290,25 @@ def lddt_loss(
all_atom_mask: torch.Tensor, all_atom_mask: torch.Tensor,
resolution: torch.Tensor, resolution: torch.Tensor,
cutoff: float = 15., cutoff: float = 15.,
num_bins: int = 50, no_bins: int = 50,
min_resolution: float = 0.1, min_resolution: float = 0.1,
max_resolution: float = 3.0, max_resolution: float = 3.0,
eps: float = 1e-10, eps: float = 1e-10,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
all_atom_positions = batch["all_atom_positions"] n = all_atom_mask.shape[-2]
all_atom_mask = batch["all_atom_mask"]
n = all_atom_mask.shape[-1]
ca_pos = residue_constants.atom_order["CA"] ca_pos = residue_constants.atom_order["CA"]
all_atom_pred_pos = all_atom_pred_pos[..., :, ca_pos, :] all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
all_atom_positions = all_atom_positions[..., :, ca_pos, :] all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., :, ca_pos:(ca_pos + 1)] # keep dim all_atom_mask = all_atom_mask[..., ca_pos:(ca_pos + 1)] # keep dim
dmat_true = torch.sqrt( dmat_true = torch.sqrt(
eps + eps +
torch.sum( torch.sum(
( (
all_atom_positions[..., None] - all_atom_positions[..., None, :] -
all_atom_positions[..., None, :] all_atom_positions[..., None, :, :]
)**2, )**2,
dim=-1, dim=-1,
) )
...@@ -320,14 +318,13 @@ def lddt_loss( ...@@ -320,14 +318,13 @@ def lddt_loss(
eps + eps +
torch.sum( torch.sum(
( (
all_atom_pred_pos[..., None] - all_atom_pred_pos[..., None, :] -
all_atom_pred_pos[..., None, :] all_atom_pred_pos[..., None, :, :]
)**2, )**2,
dim=-1, dim=-1,
) )
) )
dists_to_score = ( dists_to_score = (
(dmat_true < cutoff) * all_atom_mask * (dmat_true < cutoff) * all_atom_mask *
permute_final_dims(all_atom_mask, 1, 0) * permute_final_dims(all_atom_mask, 1, 0) *
...@@ -337,29 +334,31 @@ def lddt_loss( ...@@ -337,29 +334,31 @@ def lddt_loss(
dist_l1 = torch.abs(dmat_true - dmat_pred) dist_l1 = torch.abs(dmat_true - dmat_pred)
score = ( score = (
(dist_l1 < 0.5) + (dist_l1 < 0.5).type(dist_l1.dtype) +
(dist_l1 < 1.0) + (dist_l1 < 1.0).type(dist_l1.dtype) +
(dist_l1 < 2.0) + (dist_l1 < 2.0).type(dist_l1.dtype) +
(dist_l1 < 4.0) (dist_l1 < 4.0).type(dist_l1.dtype)
) )
score *= 0.25 score *= 0.25
norm = 1. / (eps + torch.sum(dists_to_score, dim=-1)) norm = 1. / (eps + torch.sum(dists_to_score, dim=-1))
score = norm * (eps + torch.sum(dists_to_score * score, dim=-1)) score = norm * (eps + torch.sum(dists_to_score * score, dim=-1))
# TODO: this feels a bit weird, but it's in the source
score = score.detach() score = score.detach()
bin_index = torch.floor(lddt_ca * num_bins) bin_index = torch.floor(score * no_bins).long()
bin_index = torch.clamp(bin_index, max=(no_bins - 1))
lddt_ca_one_hot = torch.nn.functional.one_hot( lddt_ca_one_hot = torch.nn.functional.one_hot(
bin_index, num_classes=num_bins bin_index, num_classes=no_bins
) )
errors = softmax_cross_entropy(logits, lddt_ca_one_hot) errors = softmax_cross_entropy(logits, lddt_ca_one_hot)
loss = torch.sum(errors * all_atom_mask) / (torch.sum(mask_ca) + eps) all_atom_mask = all_atom_mask.squeeze(-1)
loss = (
torch.sum(errors * all_atom_mask) / (torch.sum(all_atom_mask) + 1e-8)
)
loss *= ( loss *= (
(resolution >= min_resolution) & (resolution >= min_resolution) &
(resolution <= max_resolution) (resolution <= max_resolution)
...@@ -917,17 +916,16 @@ def find_structural_violations( ...@@ -917,17 +916,16 @@ def find_structural_violations(
overlap_tolerance=clash_overlap_tolerance, overlap_tolerance=clash_overlap_tolerance,
bond_length_tolerance_factor=violation_tolerance_factor bond_length_tolerance_factor=violation_tolerance_factor
) )
atom14_dists_lower_bound = restype_atom14_bounds["lower_bound"][ atom14_atom_exists = batch["atom14_atom_exists"]
batch["aatype"] atom14_dists_lower_bound = (
] atom14_pred_positions.new_tensor(restype_atom14_bounds["lower_bound"])[
atom14_dists_upper_bound = restype_atom14_bounds["upper_bound"][ batch["aatype"]
batch["aatype"] ]
]
atom14_dists_lower_bound = atom14_pred_positions.new_tensor(
atom14_dists_lower_bound
) )
atom14_dists_upper_bound = atom14_pred_positions.new_tensor( atom14_dists_upper_bound = (
atom14_dists_upper_bound atom14_pred_positions.new_tensor(restype_atom14_bounds["upper_bound"])[
batch["aatype"]
]
) )
residue_violations = within_residue_violations( residue_violations = within_residue_violations(
atom14_pred_positions=atom14_pred_positions, atom14_pred_positions=atom14_pred_positions,
...@@ -1102,6 +1100,7 @@ def violation_loss( ...@@ -1102,6 +1100,7 @@ def violation_loss(
violations: Dict[str, torch.Tensor], violations: Dict[str, torch.Tensor],
atom14_atom_exists: torch.Tensor, atom14_atom_exists: torch.Tensor,
eps=1e-6, eps=1e-6,
**kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
num_atoms = torch.sum(atom14_atom_exists) num_atoms = torch.sum(atom14_atom_exists)
l_clash = torch.sum( l_clash = torch.sum(
......
...@@ -49,7 +49,7 @@ def dict_multimap(fn, dicts): ...@@ -49,7 +49,7 @@ def dict_multimap(fn, dicts):
for k, v in first.items(): for k, v in first.items():
all_v = [d[k] for d in dicts] all_v = [d[k] for d in dicts]
if(type(v) is dict): if(type(v) is dict):
new_dict[k] = dict_multimap(all_v) new_dict[k] = dict_multimap(fn, all_v)
else: else:
new_dict[k] = fn(all_v) new_dict[k] = fn(all_v)
...@@ -122,9 +122,9 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims): ...@@ -122,9 +122,9 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
""" """
Implements the "chunking" procedure described in section 1.11.8. Implements the "chunking" procedure described in section 1.11.8.
Layer outputs and inputs are interpreted as simplified "pytrees," Layer outputs and inputs are assumed to be simple "pytrees,"
consisting only of (nested) lists, tuples, and dicts with tensor consisting only of (arbitrarily nested) lists, tuples, and dicts with
leaves. torch.Tensor leaves.
Args: Args:
layer: layer:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment