"runtime/tests/python/integration/test_direct.py" did not exist on "c3b847901099bf5c3dd174a3c8ec994b73426833"
Commit 9f24ebf0 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

fixed the overflow problems while slicing matrices

parent d2eae833
......@@ -2047,7 +2047,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
per_asym_residue_index = {}
for cur_asym_id in unique_asym_ids:
asym_mask = (batch["asym_id"] == cur_asym_id).bool()
per_asym_residue_index[int(cur_asym_id)] = batch["residue_index"][asym_mask]
per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(batch["residue_index"],asym_mask)
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
anchor_gt_idx = int(anchor_gt_asym) - 1
......@@ -2060,12 +2060,14 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
asym_mask = (batch["asym_id"] == anchor_pred_asym).bool()
anchor_residue_idx = per_asym_residue_index[int(anchor_pred_asym)]
anchor_true_pos = true_ca_poses[anchor_gt_idx][anchor_residue_idx]
anchor_true_pos = torch.index_select(true_ca_poses[anchor_gt_idx],1,anchor_residue_idx)
anchor_pred_pos = pred_ca_pos[asym_mask]
anchor_true_mask = true_ca_masks[anchor_gt_idx][anchor_residue_idx]
anchor_pred_mask = pred_ca_mask[asym_mask]
input_mask = (anchor_true_mask.to('cuda:0') * anchor_pred_mask.to('cuda:0')).bool()
# anchor_pred_pos = anchor_pred_pos.to('cuda')
anchor_true_mask = torch.index_select(true_ca_masks[anchor_gt_idx],1,anchor_residue_idx)
anchor_pred_mask =pred_ca_mask[asym_mask]
# anchor_pred_mask = anchor_pred_mask.to('cuda')
input_mask = (anchor_true_mask * anchor_pred_mask).bool()
r, x = get_optimal_transform(
anchor_true_pos,
anchor_pred_pos,mask=input_mask
......@@ -2075,7 +2077,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
del anchor_true_mask
gc.collect()
aligned_true_ca_poses = [ca.to('cpu') @ r.to('cpu') + x.to('cpu') for ca in true_ca_poses] # apply transforms
aligned_true_ca_poses = [ca @ r + x for ca in true_ca_poses] # apply transforms
align = greedy_align(
batch,
per_asym_residue_index,
......@@ -2117,7 +2119,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
permutated_labels.pop('aatype')
features.update(permutated_labels)
move_to_cpu = lambda t: (t.to('cpu'))
features = tensor_tree_map(move_to_cpu,features)
# features = tensor_tree_map(move_to_cpu,features)
if (not _return_breakdown):
cum_loss = self.loss(out,features,_return_breakdown)
print(f"cum_loss: {cum_loss}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment