Commit 9f24ebf0 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

fixed the overflow problems while slicing matrices

parent d2eae833
...@@ -2047,7 +2047,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2047,7 +2047,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
per_asym_residue_index = {} per_asym_residue_index = {}
for cur_asym_id in unique_asym_ids: for cur_asym_id in unique_asym_ids:
asym_mask = (batch["asym_id"] == cur_asym_id).bool() asym_mask = (batch["asym_id"] == cur_asym_id).bool()
per_asym_residue_index[int(cur_asym_id)] = batch["residue_index"][asym_mask] per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(batch["residue_index"],asym_mask)
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch) anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
anchor_gt_idx = int(anchor_gt_asym) - 1 anchor_gt_idx = int(anchor_gt_asym) - 1
...@@ -2060,12 +2060,14 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2060,12 +2060,14 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
entity_2_asym_list[int(cur_ent_id)] = cur_asym_id entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
asym_mask = (batch["asym_id"] == anchor_pred_asym).bool() asym_mask = (batch["asym_id"] == anchor_pred_asym).bool()
anchor_residue_idx = per_asym_residue_index[int(anchor_pred_asym)] anchor_residue_idx = per_asym_residue_index[int(anchor_pred_asym)]
anchor_true_pos = torch.index_select(true_ca_poses[anchor_gt_idx],1,anchor_residue_idx)
anchor_true_pos = true_ca_poses[anchor_gt_idx][anchor_residue_idx]
anchor_pred_pos = pred_ca_pos[asym_mask] anchor_pred_pos = pred_ca_pos[asym_mask]
anchor_true_mask = true_ca_masks[anchor_gt_idx][anchor_residue_idx]
anchor_pred_mask = pred_ca_mask[asym_mask] # anchor_pred_pos = anchor_pred_pos.to('cuda')
input_mask = (anchor_true_mask.to('cuda:0') * anchor_pred_mask.to('cuda:0')).bool() anchor_true_mask = torch.index_select(true_ca_masks[anchor_gt_idx],1,anchor_residue_idx)
anchor_pred_mask =pred_ca_mask[asym_mask]
# anchor_pred_mask = anchor_pred_mask.to('cuda')
input_mask = (anchor_true_mask * anchor_pred_mask).bool()
r, x = get_optimal_transform( r, x = get_optimal_transform(
anchor_true_pos, anchor_true_pos,
anchor_pred_pos,mask=input_mask anchor_pred_pos,mask=input_mask
...@@ -2075,7 +2077,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2075,7 +2077,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
del anchor_true_mask del anchor_true_mask
gc.collect() gc.collect()
aligned_true_ca_poses = [ca.to('cpu') @ r.to('cpu') + x.to('cpu') for ca in true_ca_poses] # apply transforms aligned_true_ca_poses = [ca @ r + x for ca in true_ca_poses] # apply transforms
align = greedy_align( align = greedy_align(
batch, batch,
per_asym_residue_index, per_asym_residue_index,
...@@ -2117,7 +2119,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2117,7 +2119,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
permutated_labels.pop('aatype') permutated_labels.pop('aatype')
features.update(permutated_labels) features.update(permutated_labels)
move_to_cpu = lambda t: (t.to('cpu')) move_to_cpu = lambda t: (t.to('cpu'))
features = tensor_tree_map(move_to_cpu,features) # features = tensor_tree_map(move_to_cpu,features)
if (not _return_breakdown): if (not _return_breakdown):
cum_loss = self.loss(out,features,_return_breakdown) cum_loss = self.loss(out,features,_return_breakdown)
print(f"cum_loss: {cum_loss}") print(f"cum_loss: {cum_loss}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment