"...internal/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "c794fef2f27d141393064665d2774b341d091393"
Unverified Commit 1a27b07f authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Take advantage of torch.where broadcasting (#447)

parent 5da2adbc
...@@ -94,7 +94,7 @@ def compute_shifts(cell: Tensor, pbc: Tensor, cutoff: float) -> Tensor: ...@@ -94,7 +94,7 @@ def compute_shifts(cell: Tensor, pbc: Tensor, cutoff: float) -> Tensor:
reciprocal_cell = cell.inverse().t() reciprocal_cell = cell.inverse().t()
inv_distances = reciprocal_cell.norm(2, -1) inv_distances = reciprocal_cell.norm(2, -1)
num_repeats = torch.ceil(cutoff * inv_distances).to(torch.long) num_repeats = torch.ceil(cutoff * inv_distances).to(torch.long)
num_repeats = torch.where(pbc, num_repeats, torch.zeros_like(num_repeats)) num_repeats = torch.where(pbc, num_repeats, num_repeats.new_zeros(()))
r1 = torch.arange(1, num_repeats[0] + 1, device=cell.device) r1 = torch.arange(1, num_repeats[0] + 1, device=cell.device)
r2 = torch.arange(1, num_repeats[1] + 1, device=cell.device) r2 = torch.arange(1, num_repeats[1] + 1, device=cell.device)
r3 = torch.arange(1, num_repeats[2] + 1, device=cell.device) r3 = torch.arange(1, num_repeats[2] + 1, device=cell.device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment