Commit beccaa4b authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

fix the aatype indexing error. now move all tensors to cpu and run calculations first

parent 80b32a3a
...@@ -2114,15 +2114,11 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2114,15 +2114,11 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
features = tensor_tree_map(lambda t: t[..., -1], features) features = tensor_tree_map(lambda t: t[..., -1], features)
# then permutate ground truth chains before calculating the loss # then permutate ground truth chains before calculating the loss
permutated_labels = self.multi_chain_perm_align(out,features,labels) permutated_labels = self.multi_chain_perm_align(out,features,labels)
permutated_labels.pop('aatype')
logger.info("finished multi-chain permutation") logger.info("finished multi-chain permutation")
features.update(permutated_labels) features.update(permutated_labels)
move_to_gpu = lambda t: (t.to('cuda:0')) move_to_cpu = lambda t: (t.to('cpu'))
features = tensor_tree_map(move_to_gpu,features) features = tensor_tree_map(move_to_cpu,features)
print(f"after moving features:",torch.cuda.memory_allocated(0))
print(f"features is {type(features)} and out is {type(out)}")
# out = tensor_tree_map(move_to_gpu,out)
for k,v in out.items():
out[k] = v.to('cuda')
self.loss(out,features) self.loss(out,features)
return permutated_labels return permutated_labels
## TODO next need to check how the ground truth label is used ## TODO next need to check how the ground truth label is used
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment