Unverified Commit fec59ac3 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Make use of PyTorch's type promotion logic to reduce memory access (#446)



PyTorch has supported type promotion since the last PTDC. We shouldn't manually convert the dtype now.
Co-authored-by: default avatarFarhad Ramezanghorbani <farhadrgh@users.noreply.github.com>
parent 1ddb371d
......@@ -135,7 +135,7 @@ def map2central(cell, coordinates, pbc):
inv_cell = torch.inverse(cell)
coordinates_cell = torch.matmul(coordinates, inv_cell)
# Step 2: wrap cell coordinates into [0, 1)
coordinates_cell -= coordinates_cell.floor() * pbc.to(coordinates_cell.dtype)
coordinates_cell -= coordinates_cell.floor() * pbc
# Step 3: convert from cell coordinates back to standard cartesian
# coordinate
return torch.matmul(coordinates_cell, cell)
......@@ -191,8 +191,8 @@ class EnergyShifter(torch.nn.Module):
"""(species, molecular energies)->(species, molecular energies + sae)
"""
species, energies = species_energies
sae = self.sae(species).to(energies.device)
return SpeciesEnergies(species, energies.to(sae.dtype) + sae)
sae = self.sae(species)
return SpeciesEnergies(species, energies + sae)
class ChemicalSymbolsToInts:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment