Unverified Commit fec59ac3 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Make use of PyTorch's type promotion logic to reduce memory access (#446)



PyTorch has supported type promotion since the last PTDC. We shouldn't manually convert the dtype now.
Co-authored-by: default avatarFarhad Ramezanghorbani <farhadrgh@users.noreply.github.com>
parent 1ddb371d
...@@ -135,7 +135,7 @@ def map2central(cell, coordinates, pbc): ...@@ -135,7 +135,7 @@ def map2central(cell, coordinates, pbc):
inv_cell = torch.inverse(cell) inv_cell = torch.inverse(cell)
coordinates_cell = torch.matmul(coordinates, inv_cell) coordinates_cell = torch.matmul(coordinates, inv_cell)
# Step 2: wrap cell coordinates into [0, 1) # Step 2: wrap cell coordinates into [0, 1)
coordinates_cell -= coordinates_cell.floor() * pbc.to(coordinates_cell.dtype) coordinates_cell -= coordinates_cell.floor() * pbc
# Step 3: convert from cell coordinates back to standard cartesian # Step 3: convert from cell coordinates back to standard cartesian
# coordinate # coordinate
return torch.matmul(coordinates_cell, cell) return torch.matmul(coordinates_cell, cell)
...@@ -191,8 +191,8 @@ class EnergyShifter(torch.nn.Module): ...@@ -191,8 +191,8 @@ class EnergyShifter(torch.nn.Module):
"""(species, molecular energies)->(species, molecular energies + sae) """(species, molecular energies)->(species, molecular energies + sae)
""" """
species, energies = species_energies species, energies = species_energies
sae = self.sae(species).to(energies.device) sae = self.sae(species)
return SpeciesEnergies(species, energies.to(sae.dtype) + sae) return SpeciesEnergies(species, energies + sae)
class ChemicalSymbolsToInts: class ChemicalSymbolsToInts:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment