Unverified Commit 6fe30cc9 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Port to torch.bool data type where it should be (#278)

* Port to torch.bool data type where it should be

* more

* more

* fix snoop

* hopefully fixed

* fix ignite

* remove snoop

* flake8
parent b0fbdca2
......@@ -326,10 +326,10 @@ for _ in range(scheduler.last_epoch + 1, max_epochs):
predicted_energies = []
num_atoms = []
for chunk_species, chunk_coordinates in batch_x:
num_atoms.append((chunk_species >= 0).sum(dim=1))
num_atoms.append((chunk_species >= 0).to(true_energies.dtype).sum(dim=1))
_, chunk_energies = model((chunk_species, chunk_coordinates))
predicted_energies.append(chunk_energies)
num_atoms = torch.cat(num_atoms).to(true_energies.dtype)
num_atoms = torch.cat(num_atoms)
predicted_energies = torch.cat(predicted_energies)
loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
optimizer.zero_grad()
......
......@@ -235,7 +235,7 @@ for _ in range(scheduler.last_epoch + 1, max_epochs):
chunk_species = chunk['species']
chunk_coordinates = chunk['coordinates']
chunk_true_forces = chunk['forces']
chunk_num_atoms = (chunk_species >= 0).sum(dim=1).to(true_energies.dtype)
chunk_num_atoms = (chunk_species >= 0).to(true_energies.dtype).sum(dim=1)
num_atoms.append(chunk_num_atoms)
# We must set `chunk_coordinates` to make it requires grad, so
......
......@@ -9,7 +9,6 @@ import itertools
import ase
import ase.io
import math
import numpy
path = os.path.dirname(os.path.realpath(__file__))
N = 97
......@@ -191,7 +190,7 @@ class TestPBCSeeEachOther(unittest.TestCase):
dtype=torch.double, requires_grad=True)
cell = torch.eye(3, dtype=torch.double) * 2
species = torch.tensor([[1, 0, 0, 0, 0]], dtype=torch.long)
pbc = torch.ones(3, dtype=torch.uint8)
pbc = torch.ones(3, dtype=torch.bool)
_, aev = self.aev_computer((species, coordinates, cell, pbc))
......@@ -203,7 +202,7 @@ class TestPBCSeeEachOther(unittest.TestCase):
def testPBCConnersSeeEachOther(self):
species = torch.tensor([[0, 0]])
cell = torch.eye(3, dtype=torch.double) * 10
pbc = torch.ones(3, dtype=torch.uint8)
pbc = torch.ones(3, dtype=torch.bool)
allshifts = torchani.aev.compute_shifts(cell, pbc, 1)
xyz1 = torch.tensor([0.1, 0.1, 0.1])
......@@ -225,7 +224,7 @@ class TestPBCSeeEachOther(unittest.TestCase):
def testPBCSurfaceSeeEachOther(self):
cell = torch.eye(3, dtype=torch.double) * 10
pbc = torch.ones(3, dtype=torch.uint8)
pbc = torch.ones(3, dtype=torch.bool)
allshifts = torchani.aev.compute_shifts(cell, pbc, 1)
species = torch.tensor([[0, 0]])
......@@ -242,7 +241,7 @@ class TestPBCSeeEachOther(unittest.TestCase):
def testPBCEdgesSeeEachOther(self):
cell = torch.eye(3, dtype=torch.double) * 10
pbc = torch.ones(3, dtype=torch.uint8)
pbc = torch.ones(3, dtype=torch.bool)
allshifts = torchani.aev.compute_shifts(cell, pbc, 1)
species = torch.tensor([[0, 0]])
......@@ -264,7 +263,7 @@ class TestPBCSeeEachOther(unittest.TestCase):
species = torch.tensor([[0, 0]])
cell = ase.geometry.cellpar_to_cell([10, 10, 10 * math.sqrt(2), 90, 45, 90])
cell = torch.tensor(ase.geometry.complete_cell(cell), dtype=torch.double)
pbc = torch.ones(3, dtype=torch.uint8)
pbc = torch.ones(3, dtype=torch.bool)
allshifts = torchani.aev.compute_shifts(cell, pbc, 1)
xyz1 = torch.tensor([0.1, 0.1, 0.05], dtype=torch.double)
......@@ -289,7 +288,7 @@ class TestAEVOnBoundary(unittest.TestCase):
[-0.1, -0.1, 1.0],
[-1.0, -1.0, -1.0]]], dtype=torch.double)
self.species = torch.tensor([[1, 0, 0, 0]])
self.pbc = torch.ones(3, dtype=torch.uint8)
self.pbc = torch.ones(3, dtype=torch.bool)
self.v1, self.v2, self.v3 = self.cell
self.center_coordinates = self.coordinates + 0.5 * (self.v1 + self.v2 + self.v3)
ani1x = torchani.models.ANI1x()
......@@ -329,7 +328,7 @@ class TestAEVOnBenzenePBC(unittest.TestCase):
filename = os.path.join(path, '../tools/generate-unit-test-expect/others/Benzene.cif')
benzene = ase.io.read(filename)
self.cell = torch.tensor(benzene.get_cell(complete=True)).float()
self.pbc = torch.tensor(benzene.get_pbc().astype(numpy.uint8), dtype=torch.uint8)
self.pbc = torch.tensor(benzene.get_pbc(), dtype=torch.bool)
species_to_tensor = torchani.utils.ChemicalSymbolsToInts(['H', 'C', 'N', 'O'])
self.species = species_to_tensor(benzene.get_chemical_symbols()).unsqueeze(0)
self.coordinates = torch.tensor(benzene.get_positions()).unsqueeze(0).float()
......
......@@ -258,8 +258,8 @@ def triple_by_molecule(atom_index1, atom_index2):
local_index2 = rev_indices[sorted_local_index2]
# compute mapping between representation of central-other to pair
sign1 = ((local_index1 < n) * 2).to(torch.long) - 1
sign2 = ((local_index2 < n) * 2).to(torch.long) - 1
sign1 = ((local_index1 < n).to(torch.long) * 2) - 1
sign2 = ((local_index2 < n).to(torch.long) * 2) - 1
return central_atom_index, local_index1 % n, local_index2 % n, sign1, sign2
......@@ -367,7 +367,7 @@ class AEVComputer(torch.nn.Module):
# These values are used when cell and pbc switch are not given.
cutoff = max(self.Rcr, self.Rca)
default_cell = torch.eye(3, dtype=self.EtaR.dtype, device=self.EtaR.device)
default_pbc = torch.zeros(3, dtype=torch.uint8, device=self.EtaR.device)
default_pbc = torch.zeros(3, dtype=torch.bool, device=self.EtaR.device)
default_shifts = compute_shifts(default_cell, default_pbc, cutoff)
self.register_buffer('default_cell', default_cell)
self.register_buffer('default_shifts', default_shifts)
......
......@@ -12,7 +12,6 @@ from . import utils
import ase.calculators.calculator
import ase.units
import copy
import numpy
class Calculator(ase.calculators.calculator.Calculator):
......@@ -65,9 +64,9 @@ class Calculator(ase.calculators.calculator.Calculator):
super(Calculator, self).calculate(atoms, properties, system_changes)
cell = torch.tensor(self.atoms.get_cell(complete=True),
dtype=self.dtype, device=self.device)
pbc = torch.tensor(self.atoms.get_pbc().astype(numpy.uint8), dtype=torch.uint8,
pbc = torch.tensor(self.atoms.get_pbc(), dtype=torch.bool,
device=self.device)
pbc_enabled = bool(pbc.any().item())
pbc_enabled = pbc.any().item()
species = self.species_to_tensor(self.atoms.get_chemical_symbols()).to(self.device)
species = species.unsqueeze(0)
coordinates = torch.tensor(self.atoms.get_positions())
......
......@@ -310,7 +310,7 @@ def load_ani_dataset(path, species_tensor_converter, batch_size, shuffle=True,
if rm_outlier:
transformed_energies = properties_['energies']
num_atoms = (atomic_properties_['species'] >= 0).sum(dim=1).to(transformed_energies.dtype)
num_atoms = (atomic_properties_['species'] >= 0).to(transformed_energies.dtype).sum(dim=1)
scaled_diff = transformed_energies / num_atoms.sqrt()
mean = transformed_energies.mean()
......
......@@ -62,8 +62,8 @@ class PerAtomDictLoss(DictLoss):
def forward(self, input_, other):
loss = self.loss(input_[self.key], other[self.key])
num_atoms = (input_['species'] >= 0).sum(dim=1)
loss /= num_atoms.to(loss.dtype).to(loss.device)
num_atoms = (input_['species'] >= 0).to(loss.dtype).to(loss.device).sum(dim=1)
loss /= num_atoms
n = loss.numel()
return loss.sum() / n
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment