Commit 4666e15e authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

update codes

parent 938782e0
...@@ -1708,7 +1708,7 @@ def kabsch_rotation(P, Q): ...@@ -1708,7 +1708,7 @@ def kabsch_rotation(P, Q):
# right-handed coordinate system. # right-handed coordinate system.
# And finally calculating the optimal rotation matrix U # And finally calculating the optimal rotation matrix U
# see http://en.wikipedia.org/wiki/Kabsch_algorithm # see http://en.wikipedia.org/wiki/Kabsch_algorithm
V, _, W = torch.linalg.svd(C) V, _, W = torch.linalg.svd(C.to('cpu'))
d = (torch.linalg.det(V) * torch.linalg.det(W)) < 0.0 d = (torch.linalg.det(V) * torch.linalg.det(W)) < 0.0
if d: if d:
...@@ -1855,7 +1855,6 @@ def greedy_align( ...@@ -1855,7 +1855,6 @@ def greedy_align(
assert best_idx is not None assert best_idx is not None
used[best_idx] = True used[best_idx] = True
align.append((i, best_idx)) align.append((i, best_idx))
print(f"align is {align}")
return align return align
...@@ -1900,7 +1899,7 @@ class AlphaFoldLoss(nn.Module): ...@@ -1900,7 +1899,7 @@ class AlphaFoldLoss(nn.Module):
def loss(self, out, batch, _return_breakdown=False): def loss(self, out, batch, _return_breakdown=False):
""" """
Rename previous forward() as loss Rename previous forward() as loss()
so that can be reused in the subclass so that can be reused in the subclass
""" """
if "violation" not in out.keys(): if "violation" not in out.keys():
...@@ -2034,7 +2033,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2034,7 +2033,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
per_asym_residue_index[int(cur_asym_id)] = batch["residue_index"][asym_mask] per_asym_residue_index[int(cur_asym_id)] = batch["residue_index"][asym_mask]
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch) anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
logger.info(f"anchor_gt_asym is chosen to be: {anchor_gt_asym}") print(f"anchor_gt_asym is chosen to be: {anchor_gt_asym}")
anchor_gt_idx = int(anchor_gt_asym) - 1 anchor_gt_idx = int(anchor_gt_asym) - 1
best_rmsd = 1e20 best_rmsd = 1e20
...@@ -2091,7 +2090,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2091,7 +2090,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
if rmsd < best_rmsd: if rmsd < best_rmsd:
best_rmsd = rmsd best_rmsd = rmsd
best_labels = merged_labels best_labels = merged_labels
print(f"finished shuffling and final align is {align}")
return best_labels return best_labels
def forward(self,out,batch,_return_breakdown=False): def forward(self,out,batch,_return_breakdown=False):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment