"examples/vscode:/vscode.git/clone" did not exist on "dddebc0df2a6952a508cd1a127c7aff0bc60d934"
Commit beccaa4b authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

fix the aatype indexing error. now move all tensors to cpu and run calculations first

parent 80b32a3a
......@@ -2114,15 +2114,11 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
features = tensor_tree_map(lambda t: t[..., -1], features)
# then permutate ground truth chains before calculating the loss
permutated_labels = self.multi_chain_perm_align(out,features,labels)
permutated_labels.pop('aatype')
logger.info("finished multi-chain permutation")
features.update(permutated_labels)
move_to_gpu = lambda t: (t.to('cuda:0'))
features = tensor_tree_map(move_to_gpu,features)
print(f"after moving features:",torch.cuda.memory_allocated(0))
print(f"features is {type(features)} and out is {type(out)}")
# out = tensor_tree_map(move_to_gpu,out)
for k,v in out.items():
out[k] = v.to('cuda')
move_to_cpu = lambda t: (t.to('cpu'))
features = tensor_tree_map(move_to_cpu,features)
self.loss(out,features)
return permutated_labels
## TODO next need to check how the ground truth label is used
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment