Commit 66a60d58 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

Start modifying the merge_labels function to make it compatible with dataloader inputs

parent e9794a62
...@@ -1610,6 +1610,7 @@ def masked_msa_loss(logits, true_msa, bert_mask, num_classes, eps=1e-8, **kwargs ...@@ -1610,6 +1610,7 @@ def masked_msa_loss(logits, true_msa, bert_mask, num_classes, eps=1e-8, **kwargs
Returns: Returns:
Masked MSA loss Masked MSA loss
""" """
print(f"logits shape: {logits.shape} true_msa shape:{true_msa.shape}")
errors = softmax_cross_entropy( errors = softmax_cross_entropy(
logits, torch.nn.functional.one_hot(true_msa, num_classes=num_classes) logits, torch.nn.functional.one_hot(true_msa, num_classes=num_classes)
) )
...@@ -1749,7 +1750,7 @@ def get_optimal_transform( ...@@ -1749,7 +1750,7 @@ def get_optimal_transform(
src_atoms = torch.zeros((1, 3), device=src_atoms.device).float() src_atoms = torch.zeros((1, 3), device=src_atoms.device).float()
tgt_atoms = src_atoms tgt_atoms = src_atoms
else: else:
src_atoms = src_atoms[mask, :] src_atoms = src_atoms.to('cuda:0')[mask, :]
tgt_atoms = tgt_atoms.to('cuda:0')[mask, :] tgt_atoms = tgt_atoms.to('cuda:0')[mask, :]
src_center = src_atoms.mean(-2, keepdim=True) src_center = src_atoms.mean(-2, keepdim=True)
tgt_center = tgt_atoms.mean(-2, keepdim=True) tgt_center = tgt_atoms.mean(-2, keepdim=True)
...@@ -1857,7 +1858,6 @@ def greedy_align( ...@@ -1857,7 +1858,6 @@ def greedy_align(
best_idx = None best_idx = None
cur_asym_list = entity_2_asym_list[int(cur_entity_ids)] cur_asym_list = entity_2_asym_list[int(cur_entity_ids)]
cur_residue_index = per_asym_residue_index[int(cur_asym_id)] cur_residue_index = per_asym_residue_index[int(cur_asym_id)]
cur_pred_pos = pred_ca_pos[asym_mask] cur_pred_pos = pred_ca_pos[asym_mask]
cur_pred_mask = pred_ca_mask[asym_mask] cur_pred_mask = pred_ca_mask[asym_mask]
for next_asym_id in cur_asym_list: for next_asym_id in cur_asym_list:
...@@ -1890,27 +1890,30 @@ def merge_labels(batch, per_asym_residue_index, labels, align): ...@@ -1890,27 +1890,30 @@ def merge_labels(batch, per_asym_residue_index, labels, align):
num_res = batch["msa_mask"].shape[-1] num_res = batch["msa_mask"].shape[-1]
outs = {} outs = {}
for k, v in labels[0].items(): for k, v in labels[0].items():
if k in [
"resolution",
]:
continue
cur_out = {} cur_out = {}
for i, j in align: for i, j in align:
label = labels[j][k] label = labels[j][k]
# to 1-based # to 1-based
cur_residue_index = per_asym_residue_index[i + 1] cur_residue_index = per_asym_residue_index[i + 1]
cur_out[i] = label[cur_residue_index] if len(v.shape)==0 or "template" in k:
continue
else:
cur_out[i] = label[cur_residue_index]
cur_out = [x[1] for x in sorted(cur_out.items())] cur_out = [x[1] for x in sorted(cur_out.items())]
new_v = torch.concat(cur_out, dim=0) if len(cur_out)>0:
merged_nres = new_v.shape[0] new_v = torch.concat(cur_out, dim=0)
assert ( merged_nres = new_v.shape[0]
merged_nres <= num_res assert (
), f"bad merged num res: {merged_nres} > {num_res}. something is wrong." merged_nres <= num_res
if merged_nres < num_res: # must pad ), f"bad merged num res: {merged_nres} > {num_res}. something is wrong."
pad_dim = new_v.shape[1:] if merged_nres < num_res: # must pad
pad_v = new_v.new_zeros((num_res - merged_nres, *pad_dim)) pad_dim = new_v.shape[1:]
new_v = torch.concat((new_v, pad_v), dim=0) pad_v = new_v.new_zeros((num_res - merged_nres, *pad_dim))
outs[k] = new_v new_v = torch.concat((new_v, pad_v), dim=0)
outs[k] = new_v
print(f"finished merging")
for k,v in outs.items():
print(f"{k}:{v.shape}")
return outs return outs
...@@ -2050,7 +2053,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2050,7 +2053,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
true_ca_masks = [ true_ca_masks = [
l["all_atom_mask"][..., ca_idx].float() for l in labels l["all_atom_mask"][..., ca_idx].float() for l in labels
] # list([nres,]) ] # list([nres,])
unique_asym_ids = torch.unique(batch["asym_id"]) unique_asym_ids = torch.unique(batch["asym_id"])
per_asym_residue_index = {} per_asym_residue_index = {}
...@@ -2059,7 +2061,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2059,7 +2061,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
per_asym_residue_index[int(cur_asym_id)] = batch["residue_index"][asym_mask] per_asym_residue_index[int(cur_asym_id)] = batch["residue_index"][asym_mask]
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch) anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
print(f"anchor_gt_asym is : {anchor_gt_asym} and anchor_pred_asym is {anchor_pred_asym}")
anchor_gt_idx = int(anchor_gt_asym) - 1 anchor_gt_idx = int(anchor_gt_asym) - 1
unique_entity_ids = torch.unique(batch["entity_id"]) unique_entity_ids = torch.unique(batch["entity_id"])
...@@ -2100,15 +2101,13 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2100,15 +2101,13 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
del aligned_true_ca_poses del aligned_true_ca_poses
del r,x del r,x
gc.collect() gc.collect()
print(f"finished multi-chain permutation and final align is {align}")
merged_labels = merge_labels( merged_labels = merge_labels(
batch, batch,
per_asym_residue_index, per_asym_residue_index,
labels, labels,
align, align,
) )
print(f"finished multi-chain permutation and final align is {align}")
return merged_labels return merged_labels
...@@ -2122,9 +2121,9 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2122,9 +2121,9 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
batch: a pair of input features and its corresponding ground truth structure batch: a pair of input features and its corresponding ground truth structure
""" """
features,labels = batch features,labels = batch
features['resolution'] = labels[2]['resolution'] # firstly update the resolution feature
    # first remove the recycling dimension of input features    # first remove the recycling dimension of input features
features = tensor_tree_map(lambda t: t[..., -1], features) features = tensor_tree_map(lambda t: t[..., -1], features)
features['resolution'] = labels[0]['resolution']
    # then permute ground truth chains before calculating the loss    # then permute ground truth chains before calculating the loss
permutated_labels = self.multi_chain_perm_align(out,features,labels) permutated_labels = self.multi_chain_perm_align(out,features,labels)
permutated_labels.pop('aatype') permutated_labels.pop('aatype')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment