Commit 80f0d617 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

solved cuda error just for now by moving the 2 tensors to cpu

parent 3ab9da6e
...@@ -38,6 +38,8 @@ import random ...@@ -38,6 +38,8 @@ import random
from openfold.np import residue_constants as rc from openfold.np import residue_constants as rc
import logging import logging
import procrustes import procrustes
from openfold.utils.tensor_utils import tensor_tree_map
import gc
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def softmax_cross_entropy(logits, labels): def softmax_cross_entropy(logits, labels):
...@@ -842,6 +844,7 @@ def between_residue_bond_loss( ...@@ -842,6 +844,7 @@ def between_residue_bond_loss(
] + next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[ ] + next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[
1 1
] ]
c_n_bond_length_error = torch.sqrt(eps + (c_n_bond_length - gt_length) ** 2) c_n_bond_length_error = torch.sqrt(eps + (c_n_bond_length - gt_length) ** 2)
c_n_loss_per_residue = torch.nn.functional.relu( c_n_loss_per_residue = torch.nn.functional.relu(
c_n_bond_length_error - tolerance_factor_soft * gt_stddev c_n_bond_length_error - tolerance_factor_soft * gt_stddev
...@@ -1741,9 +1744,16 @@ def get_optimal_transform( ...@@ -1741,9 +1744,16 @@ def get_optimal_transform(
src_center = src_atoms.mean(-2, keepdim=True) src_center = src_atoms.mean(-2, keepdim=True)
tgt_center = tgt_atoms.mean(-2, keepdim=True) tgt_center = tgt_atoms.mean(-2, keepdim=True)
r = kabsch_rotation(src_atoms,tgt_atoms) r = kabsch_rotation(src_atoms,tgt_atoms)
del src_atoms,tgt_atoms,
gc.collect()
tgt_center,src_center = tgt_center.to('cuda:0'),src_center.to('cuda:0') tgt_center,src_center = tgt_center.to('cuda:0'),src_center.to('cuda:0')
x = tgt_center - src_center @ r.to('cuda:0') x = tgt_center.to('cpu') - src_center.to('cpu') @ r.to('cpu')
return r, x
del tgt_center,src_center,mask
gc.collect()
return r, x.to('cuda')
def compute_rmsd( def compute_rmsd(
...@@ -1756,6 +1766,9 @@ def compute_rmsd( ...@@ -1756,6 +1766,9 @@ def compute_rmsd(
true_atom_pos = true_atom_pos.to('cuda:0') true_atom_pos = true_atom_pos.to('cuda:0')
pred_atom_pos = pred_atom_pos.to('cuda:0') pred_atom_pos = pred_atom_pos.to('cuda:0')
sq_diff = torch.square(true_atom_pos - pred_atom_pos).sum(dim=-1, keepdim=False) sq_diff = torch.square(true_atom_pos - pred_atom_pos).sum(dim=-1, keepdim=False)
del true_atom_pos
del pred_atom_pos
gc.collect()
if atom_mask is not None: if atom_mask is not None:
sq_diff = sq_diff.to('cpu')[atom_mask.to('cpu')] # somehow it causes overflow on cuda so moved to cpu sq_diff = sq_diff.to('cpu')[atom_mask.to('cpu')] # somehow it causes overflow on cuda so moved to cpu
msd = torch.mean(sq_diff) msd = torch.mean(sq_diff)
...@@ -1830,7 +1843,7 @@ def greedy_align( ...@@ -1830,7 +1843,7 @@ def greedy_align(
used[i] = True used[i] = True
continue continue
cur_entity_ids = batch["entity_id"][asym_mask][0] cur_entity_ids = batch["entity_id"][asym_mask][0]
best_rmsd = 1e20 best_rmsd = torch.inf
best_idx = None best_idx = None
cur_asym_list = entity_2_asym_list[int(cur_entity_ids)] cur_asym_list = entity_2_asym_list[int(cur_entity_ids)]
cur_residue_index = per_asym_residue_index[int(cur_asym_id)] cur_residue_index = per_asym_residue_index[int(cur_asym_id)]
...@@ -1847,7 +1860,8 @@ def greedy_align( ...@@ -1847,7 +1860,8 @@ def greedy_align(
rmsd = compute_rmsd( rmsd = compute_rmsd(
cropped_pos, cur_pred_pos, (cur_pred_mask.to('cuda:0') * mask.to('cuda:0')).bool() cropped_pos, cur_pred_pos, (cur_pred_mask.to('cuda:0') * mask.to('cuda:0')).bool()
) )
if rmsd < best_rmsd: print(f"rmsd is {rmsd}")
if rmsd < best_rmsd:
best_rmsd = rmsd best_rmsd = rmsd
best_idx = j best_idx = j
assert best_idx is not None assert best_idx is not None
...@@ -2047,13 +2061,17 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2047,13 +2061,17 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
anchor_pred_pos = pred_ca_pos[asym_mask] anchor_pred_pos = pred_ca_pos[asym_mask]
anchor_true_mask = true_ca_masks[anchor_gt_idx][anchor_residue_idx] anchor_true_mask = true_ca_masks[anchor_gt_idx][anchor_residue_idx]
anchor_pred_mask = pred_ca_mask[asym_mask] anchor_pred_mask = pred_ca_mask[asym_mask]
input_mask = (anchor_true_mask.to('cuda:0') * anchor_pred_mask.to('cuda:0')).bool()
r, x = get_optimal_transform( r, x = get_optimal_transform(
anchor_true_pos, anchor_true_pos,
anchor_pred_pos, anchor_pred_pos,mask=input_mask
(anchor_true_mask.to('cuda:0') * anchor_pred_mask.to('cuda:0')).bool(),
) )
del input_mask # just to save memory
aligned_true_ca_poses = [ca.to('cuda:0') @ r.to('cuda:0') + x.to('cuda:0') for ca in true_ca_poses] # apply transforms del anchor_pred_mask
del anchor_true_mask
gc.collect()
aligned_true_ca_poses = [ca.to('cpu') @ r.to('cpu') + x.to('cpu') for ca in true_ca_poses] # apply transforms
align = greedy_align( align = greedy_align(
batch, batch,
per_asym_residue_index, per_asym_residue_index,
...@@ -2064,6 +2082,11 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2064,6 +2082,11 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
aligned_true_ca_poses, aligned_true_ca_poses,
true_ca_masks, true_ca_masks,
) )
del aligned_true_ca_poses
del r,x
gc.collect()
merged_labels = merge_labels( merged_labels = merge_labels(
batch, batch,
per_asym_residue_index, per_asym_residue_index,
...@@ -2091,6 +2114,10 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2091,6 +2114,10 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
permutated_labels = self.multi_chain_perm_align(out,features,labels) permutated_labels = self.multi_chain_perm_align(out,features,labels)
logger.info("finished multi-chain permutation") logger.info("finished multi-chain permutation")
features.update(permutated_labels) features.update(permutated_labels)
move_to_gpu = lambda t: (t.to('cuda:0'))
features = tensor_tree_map(move_to_gpu,features)
print(f"after moving features:",torch.cuda.memory_allocated(0))
# out = tensor_tree_map(move_to_gpu,out)
self.loss(out,features) self.loss(out,features)
return permutated_labels return permutated_labels
## TODO next need to check how the ground truth label is used ## TODO next need to check how the ground truth label is used
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment