"vscode:/vscode.git/clone" did not exist on "fcba33580eab1c6dc559a3f245cd742538e1a944"
Commit 3ab9da6e authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

move some tensors back to gpu

parent a420160f
...@@ -1704,6 +1704,7 @@ def kabsch_rotation(P, Q): ...@@ -1704,6 +1704,7 @@ def kabsch_rotation(P, Q):
# Will continue trying SVD until the optimal rotaion is calculated # Will continue trying SVD until the optimal rotaion is calculated
# # # #
try: try:
# first need to load P and Q to cpu otherwise cannot extract the numpy matrices
rotation = procrustes.rotational(P.to('cpu').numpy(), rotation = procrustes.rotational(P.to('cpu').numpy(),
Q.to('cpu').numpy(),translate=True) Q.to('cpu').numpy(),translate=True)
finished_rotation = True finished_rotation = True
...@@ -1736,12 +1737,12 @@ def get_optimal_transform( ...@@ -1736,12 +1737,12 @@ def get_optimal_transform(
tgt_atoms = src_atoms tgt_atoms = src_atoms
else: else:
src_atoms = src_atoms[mask, :] src_atoms = src_atoms[mask, :]
tgt_atoms = tgt_atoms[mask, :] tgt_atoms = tgt_atoms.to('cuda:0')[mask, :]
src_center = src_atoms.mean(-2, keepdim=True) src_center = src_atoms.mean(-2, keepdim=True)
tgt_center = tgt_atoms.mean(-2, keepdim=True) tgt_center = tgt_atoms.mean(-2, keepdim=True)
r = kabsch_rotation(src_atoms,tgt_atoms) r = kabsch_rotation(src_atoms,tgt_atoms)
tgt_center,src_center = tgt_center.to('cpu'),src_center.to('cpu') # load to cpu memory just in case tgt_center,src_center = tgt_center.to('cuda:0'),src_center.to('cuda:0')
x = tgt_center - src_center @ r x = tgt_center - src_center @ r.to('cuda:0')
return r, x return r, x
...@@ -1752,9 +1753,11 @@ def compute_rmsd( ...@@ -1752,9 +1753,11 @@ def compute_rmsd(
eps: float = 1e-6, eps: float = 1e-6,
) -> torch.Tensor: ) -> torch.Tensor:
# shape check # shape check
true_atom_pos = true_atom_pos.to('cuda:0')
pred_atom_pos = pred_atom_pos.to('cuda:0')
sq_diff = torch.square(true_atom_pos - pred_atom_pos).sum(dim=-1, keepdim=False) sq_diff = torch.square(true_atom_pos - pred_atom_pos).sum(dim=-1, keepdim=False)
if atom_mask is not None: if atom_mask is not None:
sq_diff = sq_diff[atom_mask] sq_diff = sq_diff.to('cpu')[atom_mask.to('cpu')] # somehow it causes overflow on cuda so moved to cpu
msd = torch.mean(sq_diff) msd = torch.mean(sq_diff)
msd = torch.nan_to_num(msd, nan=1e8) msd = torch.nan_to_num(msd, nan=1e8)
return torch.sqrt(msd + eps) return torch.sqrt(msd + eps)
...@@ -1842,7 +1845,7 @@ def greedy_align( ...@@ -1842,7 +1845,7 @@ def greedy_align(
cropped_pos = true_ca_poses[j] cropped_pos = true_ca_poses[j]
mask = true_ca_masks[j][cur_residue_index] mask = true_ca_masks[j][cur_residue_index]
rmsd = compute_rmsd( rmsd = compute_rmsd(
cropped_pos, cur_pred_pos, (cur_pred_mask.to('cpu') * mask.to('cpu')).bool() cropped_pos, cur_pred_pos, (cur_pred_mask.to('cuda:0') * mask.to('cuda:0')).bool()
) )
if rmsd < best_rmsd: if rmsd < best_rmsd:
best_rmsd = rmsd best_rmsd = rmsd
...@@ -1901,7 +1904,7 @@ class AlphaFoldLoss(nn.Module): ...@@ -1901,7 +1904,7 @@ class AlphaFoldLoss(nn.Module):
out["violation"] = find_structural_violations( out["violation"] = find_structural_violations(
batch, batch,
out["sm"]["positions"][-1], out["sm"]["positions"][-1],
**self.config.violation, **self.config.loss.violation,
) )
if "renamed_atom14_gt_positions" not in out.keys(): if "renamed_atom14_gt_positions" not in out.keys():
...@@ -2047,10 +2050,10 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2047,10 +2050,10 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
r, x = get_optimal_transform( r, x = get_optimal_transform(
anchor_true_pos, anchor_true_pos,
anchor_pred_pos, anchor_pred_pos,
(anchor_true_mask.to('cpu') * anchor_pred_mask.to('cpu')).bool(), (anchor_true_mask.to('cuda:0') * anchor_pred_mask.to('cuda:0')).bool(),
) )
aligned_true_ca_poses = [ca.to('cpu') @ r.to('cpu') + x.to('cpu') for ca in true_ca_poses] # apply transforms aligned_true_ca_poses = [ca.to('cuda:0') @ r.to('cuda:0') + x.to('cuda:0') for ca in true_ca_poses] # apply transforms
align = greedy_align( align = greedy_align(
batch, batch,
per_asym_residue_index, per_asym_residue_index,
...@@ -2087,8 +2090,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2087,8 +2090,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
# then permutate ground truth chains before calculating the loss # then permutate ground truth chains before calculating the loss
permutated_labels = self.multi_chain_perm_align(out,features,labels) permutated_labels = self.multi_chain_perm_align(out,features,labels)
logger.info("finished multi-chain permutation") logger.info("finished multi-chain permutation")
# features.update(permutated_labels) features.update(permutated_labels)
# self.loss(out,features) self.loss(out,features)
return permutated_labels return permutated_labels
## TODO next need to check how the ground truth label is used ## TODO next need to check how the ground truth label is used
# in loss calculation. # in loss calculation.
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment