Commit 33941e46 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix loss bugs

parent 1bc68426
......@@ -211,7 +211,7 @@ config = mlc.ConfigDict({
"min_resolution": 0.1,
"max_resolution": 3.0,
"cutoff": 15.,
"num_bins": 50,
"no_bins": 50,
"eps": 1e-10,
"weight": 0.01,
},
......
......@@ -49,7 +49,7 @@ class AuxiliaryHeads(nn.Module):
def forward(self, outputs):
aux_out = {}
lddt_logits = self.plddt(outputs["single"])
lddt_logits = self.plddt(outputs["sm"]["single"])
aux_out["lddt_logits"] = lddt_logits
# Required for relaxation later on
......
......@@ -751,7 +751,7 @@ class StructureModule(nn.Module):
t = t.stop_rot_gradient()
outputs = stack_tensor_dicts(outputs)
outputs["single_act"] = s
outputs["single"] = s
return outputs
......
......@@ -74,6 +74,7 @@ def get_chi_atom_indices():
def compute_residx(batch):
float_type = batch["seq_mask"].dtype
aatype = batch["aatype"]
restype_atom14_to_atom37 = [] # mapping (restype, atom37) --> atom14
......@@ -104,34 +105,28 @@ def compute_residx(batch):
restype_atom37_to_atom14.append([0] * 37)
restype_atom14_mask.append([0.] * 14)
restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32)
restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32)
restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32)
residx_atom14_to_atom37 = np.take_along_axis(
restype_atom14_to_atom37,
aatype[..., None],
axis=0
restype_atom14_to_atom37 = aatype.new_tensor(
restype_atom14_to_atom37
)
restype_atom37_to_atom14 = aatype.new_tensor(
restype_atom37_to_atom14
)
residx_atom14_mask = np.take_along_axis(
restype_atom14_mask,
aatype[..., None],
axis=0,
restype_atom14_mask = aatype.new_tensor(
restype_atom14_mask, dtype=float_type
)
residx_atom14_to_atom37 = restype_atom14_to_atom37[aatype]
residx_atom14_mask = restype_atom14_mask[aatype]
batch['atom14_atom_exists'] = residx_atom14_mask
batch['residx_atom14_to_atom37'] = residx_atom14_to_atom37.long()
batch['residx_atom14_to_atom37'] = residx_atom14_to_atom37
# create the gather indices for mapping back
residx_atom37_to_atom14 = np.take_along_axis(
restype_atom37_to_atom14,
aatype[..., None],
axis=0,
)
batch['residx_atom37_to_atom14'] = residx_atom37_to_atom14.long()
residx_atom37_to_atom14 = restype_atom37_to_atom14[aatype]
batch['residx_atom37_to_atom14'] = residx_atom37_to_atom14
# create the corresponding mask
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
restype_atom37_mask = torch.zeros([21, 37], dtype=float_type)
for restype, restype_letter in enumerate(residue_constants.restypes):
restype_name = residue_constants.restype_1to3[restype_letter]
atom_names = residue_constants.residue_atoms[restype_name]
......@@ -139,11 +134,7 @@ def compute_residx(batch):
atom_type = residue_constants.atom_order[atom_name]
restype_atom37_mask[restype, atom_type] = 1
residx_atom37_mask = np.take_along_axis(
restype_atom37_mask,
aatype[..., None],
axis=0,
)
residx_atom37_mask = restype_atom37_mask[aatype]
batch['atom37_atom_exists'] = residx_atom37_mask
......
......@@ -27,6 +27,7 @@ from openfold.utils.tensor_utils import (
tree_map,
tensor_tree_map,
masked_mean,
permute_final_dims,
)
......@@ -289,28 +290,25 @@ def lddt_loss(
all_atom_mask: torch.Tensor,
resolution: torch.Tensor,
cutoff: float = 15.,
num_bins: int = 50,
no_bins: int = 50,
min_resolution: float = 0.1,
max_resolution: float = 3.0,
eps: float = 1e-10,
**kwargs,
) -> torch.Tensor:
all_atom_positions = batch["all_atom_positions"]
all_atom_mask = batch["all_atom_mask"]
n = all_atom_mask.shape[-1]
n = all_atom_mask.shape[-2]
ca_pos = residue_constants.atom_order["CA"]
all_atom_pred_pos = all_atom_pred_pos[..., :, ca_pos, :]
all_atom_positions = all_atom_positions[..., :, ca_pos, :]
all_atom_mask = all_atom_mask[..., :, ca_pos:(ca_pos + 1)] # keep dim
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos:(ca_pos + 1)] # keep dim
dmat_true = torch.sqrt(
eps +
torch.sum(
(
all_atom_positions[..., None] -
all_atom_positions[..., None, :]
all_atom_positions[..., None, :] -
all_atom_positions[..., None, :, :]
)**2,
dim=-1,
)
......@@ -320,14 +318,13 @@ def lddt_loss(
eps +
torch.sum(
(
all_atom_pred_pos[..., None] -
all_atom_pred_pos[..., None, :]
all_atom_pred_pos[..., None, :] -
all_atom_pred_pos[..., None, :, :]
)**2,
dim=-1,
)
)
dists_to_score = (
(dmat_true < cutoff) * all_atom_mask *
permute_final_dims(all_atom_mask, 1, 0) *
......@@ -337,29 +334,31 @@ def lddt_loss(
dist_l1 = torch.abs(dmat_true - dmat_pred)
score = (
(dist_l1 < 0.5) +
(dist_l1 < 1.0) +
(dist_l1 < 2.0) +
(dist_l1 < 4.0)
(dist_l1 < 0.5).type(dist_l1.dtype) +
(dist_l1 < 1.0).type(dist_l1.dtype) +
(dist_l1 < 2.0).type(dist_l1.dtype) +
(dist_l1 < 4.0).type(dist_l1.dtype)
)
score *= 0.25
norm = 1. / (eps + torch.sum(dists_to_score, dim=-1))
score = norm * (eps + torch.sum(dists_to_score * score, dim=-1))
# TODO: this feels a bit weird, but it's in the source
score = score.detach()
bin_index = torch.floor(lddt_ca * num_bins)
bin_index = torch.floor(score * no_bins).long()
bin_index = torch.clamp(bin_index, max=(no_bins - 1))
lddt_ca_one_hot = torch.nn.functional.one_hot(
bin_index, num_classes=num_bins
bin_index, num_classes=no_bins
)
errors = softmax_cross_entropy(logits, lddt_ca_one_hot)
loss = torch.sum(errors * all_atom_mask) / (torch.sum(mask_ca) + eps)
all_atom_mask = all_atom_mask.squeeze(-1)
loss = (
torch.sum(errors * all_atom_mask) / (torch.sum(all_atom_mask) + 1e-8)
)
loss *= (
(resolution >= min_resolution) &
(resolution <= max_resolution)
......@@ -917,17 +916,16 @@ def find_structural_violations(
overlap_tolerance=clash_overlap_tolerance,
bond_length_tolerance_factor=violation_tolerance_factor
)
atom14_dists_lower_bound = restype_atom14_bounds["lower_bound"][
batch["aatype"]
]
atom14_dists_upper_bound = restype_atom14_bounds["upper_bound"][
batch["aatype"]
]
atom14_dists_lower_bound = atom14_pred_positions.new_tensor(
atom14_dists_lower_bound
atom14_atom_exists = batch["atom14_atom_exists"]
atom14_dists_lower_bound = (
atom14_pred_positions.new_tensor(restype_atom14_bounds["lower_bound"])[
batch["aatype"]
]
)
atom14_dists_upper_bound = atom14_pred_positions.new_tensor(
atom14_dists_upper_bound
atom14_dists_upper_bound = (
atom14_pred_positions.new_tensor(restype_atom14_bounds["upper_bound"])[
batch["aatype"]
]
)
residue_violations = within_residue_violations(
atom14_pred_positions=atom14_pred_positions,
......@@ -1102,6 +1100,7 @@ def violation_loss(
violations: Dict[str, torch.Tensor],
atom14_atom_exists: torch.Tensor,
eps=1e-6,
**kwargs,
) -> torch.Tensor:
num_atoms = torch.sum(atom14_atom_exists)
l_clash = torch.sum(
......
......@@ -49,7 +49,7 @@ def dict_multimap(fn, dicts):
for k, v in first.items():
all_v = [d[k] for d in dicts]
if(type(v) is dict):
new_dict[k] = dict_multimap(all_v)
new_dict[k] = dict_multimap(fn, all_v)
else:
new_dict[k] = fn(all_v)
......@@ -122,9 +122,9 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
"""
Implements the "chunking" procedure described in section 1.11.8.
Layer outputs and inputs are interpreted as simplified "pytrees,"
consisting only of (nested) lists, tuples, and dicts with tensor
leaves.
Layer outputs and inputs are assumed to be simple "pytrees,"
consisting only of (arbitrarily nested) lists, tuples, and dicts with
torch.Tensor leaves.
Args:
layer:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment