Commit c715b138 authored by Jennifer Wei's avatar Jennifer Wei
Browse files

Merge remote-tracking branch 'refs/remotes/jnwei/pl_upgrades' into pl_upgrades

parents e3e09c46 76fb7ce6
......@@ -35,10 +35,10 @@ def _superimpose_np(reference, coords):
def _superimpose_single(reference, coords):
reference_np = reference.detach().to(torch.float).cpu().numpy()
coords_np = coords.detach().to(torch.float).cpu().numpy()
superimposed, rmsd = _superimpose_np(reference_np, coords_np)
return coords.new_tensor(superimposed), coords.new_tensor(rmsd)
reference_np = reference.detach().to(torch.float).cpu().numpy()
coords_np = coords.detach().to(torch.float).cpu().numpy()
superimposed, rmsd = _superimpose_np(reference_np, coords_np)
return coords.new_tensor(superimposed), coords.new_tensor(rmsd)
def superimpose(reference, coords, mask):
......
......@@ -682,9 +682,9 @@ if __name__ == "__main__":
trainer_group.add_argument(
"--reload_dataloaders_every_n_epochs", type=int, default=1,
)
trainer_group.add_argument("--accumulate_grad_batches", type=int, default=1,
help="Accumulate gradients over k batches before next optimizer step.")
trainer_group.add_argument(
"--accumulate_grad_batches", type=int, default=1,
help="Accumulate gradients over k batches before next optimizer step.")
args = parser.parse_args()
......@@ -700,5 +700,4 @@ if __name__ == "__main__":
raise ValueError(
"Choose between loading pretrained Jax-weights and a checkpoint-path")
main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment