Unverified Commit 1a27b07f authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Take advantage of torch.where broadcasting (#447)

parent 5da2adbc
......@@ -94,7 +94,7 @@ def compute_shifts(cell: Tensor, pbc: Tensor, cutoff: float) -> Tensor:
reciprocal_cell = cell.inverse().t()
inv_distances = reciprocal_cell.norm(2, -1)
num_repeats = torch.ceil(cutoff * inv_distances).to(torch.long)
num_repeats = torch.where(pbc, num_repeats, torch.zeros_like(num_repeats))
num_repeats = torch.where(pbc, num_repeats, num_repeats.new_zeros(()))
r1 = torch.arange(1, num_repeats[0] + 1, device=cell.device)
r2 = torch.arange(1, num_repeats[1] + 1, device=cell.device)
r3 = torch.arange(1, num_repeats[2] + 1, device=cell.device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment