Unverified commit 50fb6cdb, authored by Gao, Xiang; committed by GitHub

Modify training-related APIs to batch the whole data by padding (#63)

parent b7cab4f1
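The change in a nutshell: instead of cutting each molecule into per-molecule chunks, training data is now padded to a common atom count and batched as plain tensors, with species index -1 marking padding atoms (the value the AEV computer masks out in the diffs below). The diff only shows call sites of `torchani.padding.pad_and_batch`, so here is a minimal sketch of the idea, not the library's implementation; zero-filled coordinates for padding atoms and the exact input shapes are assumptions.

```python
import torch


def pad_and_batch_sketch(species_coordinates):
    """Hypothetical stand-in for torchani.padding.pad_and_batch.

    Each item is (species, coordinates) with shapes (A_i,) and
    (C_i, A_i, 3); molecules are right-padded to the largest A_i and
    concatenated along the conformation dimension.
    """
    max_atoms = max(s.shape[0] for s, _ in species_coordinates)
    batched_species, batched_coordinates = [], []
    for species, coordinates in species_coordinates:
        pad = max_atoms - species.shape[0]
        # -1 marks padding atoms; downstream code masks on species == -1
        species = torch.nn.functional.pad(species, (0, pad), value=-1)
        # zero-filled coordinates for padding atoms (assumption)
        coordinates = torch.nn.functional.pad(coordinates, (0, 0, 0, pad))
        # one species row per conformation of this molecule
        batched_species.append(species.expand(coordinates.shape[0], -1))
        batched_coordinates.append(coordinates)
    return torch.cat(batched_species), torch.cat(batched_coordinates)
```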
[5 binary files changed — no preview for this file type]
@@ -14,7 +14,6 @@ class TestEnergies(unittest.TestCase):
     def setUp(self):
         self.tolerance = 5e-5
         aev_computer = torchani.AEVComputer()
-        self.prepare = torchani.PrepareInput(aev_computer.species)
         nnp = torchani.models.NeuroChemNNP(aev_computer.species)
         shift_energy = torchani.EnergyShifter(aev_computer.species)
         self.model = torch.nn.Sequential(aev_computer, nnp, shift_energy)
@@ -24,7 +23,6 @@ class TestEnergies(unittest.TestCase):
             datafile = os.path.join(path, 'test_data/{}'.format(i))
             with open(datafile, 'rb') as f:
                 coordinates, species, _, _, energies, _ = pickle.load(f)
-                species, coordinates = self.prepare((species, coordinates))
                 _, energies_ = self.model((species, coordinates))
                 max_diff = (energies - energies_).abs().max().item()
                 self.assertLess(max_diff, self.tolerance)
@@ -36,8 +34,7 @@ class TestEnergies(unittest.TestCase):
             datafile = os.path.join(path, 'test_data/{}'.format(i))
             with open(datafile, 'rb') as f:
                 coordinates, species, _, _, e, _ = pickle.load(f)
-                species_coordinates.append(
-                    self.prepare((species, coordinates)))
+                species_coordinates.append((species, coordinates))
                 energies.append(e)
         species, coordinates = torchani.padding.pad_and_batch(
             species_coordinates)
...
-import torch
-import torchani
-import unittest
-import random
-
-
-class TestEnergyShifter(unittest.TestCase):
-
-    def setUp(self):
-        self.tol = 1e-5
-        self.species = torchani.AEVComputer().species
-        self.prepare = torchani.PrepareInput(self.species)
-        self.shift_energy = torchani.EnergyShifter(self.species)
-
-    def testSAEMatch(self):
-        species_coordinates = []
-        saes = []
-        for _ in range(10):
-            k = random.choice(range(5, 30))
-            species = random.choices(self.species, k=k)
-            coordinates = torch.empty(1, k, 3)
-            species_coordinates.append(self.prepare((species, coordinates)))
-            e = self.shift_energy.sae_from_list(species)
-            saes.append(e)
-        species, _ = torchani.padding.pad_and_batch(species_coordinates)
-        saes_ = self.shift_energy.sae_from_tensor(species)
-        saes = torch.tensor(saes, dtype=saes_.dtype, device=saes_.device)
-        self.assertLess((saes - saes_).abs().max(), self.tol)
-
-
-if __name__ == '__main__':
-    unittest.main()
@@ -19,14 +19,13 @@ class TestEnsemble(unittest.TestCase):
         n = torchani.buildin_ensemble
         prefix = torchani.buildin_model_prefix
         aev = torchani.AEVComputer()
-        prepare = torchani.PrepareInput(aev.species)
         ensemble = torchani.models.NeuroChemNNP(aev.species, ensemble=True)
-        ensemble = torch.nn.Sequential(prepare, aev, ensemble)
+        ensemble = torch.nn.Sequential(aev, ensemble)
         models = [torchani.models.
                   NeuroChemNNP(aev.species, ensemble=False,
                                from_=prefix + '{}/networks/'.format(i))
                   for i in range(n)]
-        models = [torch.nn.Sequential(prepare, aev, m) for m in models]
+        models = [torch.nn.Sequential(aev, m) for m in models]
         _, energy1 = ensemble((species, coordinates))
         force1 = torch.autograd.grad(energy1.sum(), coordinates)[0]
...
@@ -13,10 +13,8 @@ class TestForce(unittest.TestCase):
     def setUp(self):
         self.tolerance = 1e-5
         aev_computer = torchani.AEVComputer()
-        self.prepare = torchani.PrepareInput(aev_computer.species)
         nnp = torchani.models.NeuroChemNNP(aev_computer.species)
-        self.model = torch.nn.Sequential(self.prepare, aev_computer, nnp)
-        self.prepared2e = torch.nn.Sequential(aev_computer, nnp)
+        self.model = torch.nn.Sequential(aev_computer, nnp)

     def testIsomers(self):
         for i in range(N):
@@ -37,13 +35,12 @@ class TestForce(unittest.TestCase):
             datafile = os.path.join(path, 'test_data/{}'.format(i))
             with open(datafile, 'rb') as f:
                 coordinates, species, _, _, _, forces = pickle.load(f)
-                species, coordinates = self.prepare((species, coordinates))
                 coordinates = torch.tensor(coordinates, requires_grad=True)
                 species_coordinates.append((species, coordinates))
                 coordinates_forces.append((coordinates, forces))
         species, coordinates = torchani.padding.pad_and_batch(
             species_coordinates)
-        _, energies = self.prepared2e((species, coordinates))
+        _, energies = self.model((species, coordinates))
         energies = energies.sum()
         for coordinates, forces in coordinates_forces:
             derivative = torch.autograd.grad(energies, coordinates,
...
-import sys
-
-if sys.version_info.major >= 3:
-    import os
-    import unittest
-    import torch
-    from ignite.engine import create_supervised_trainer, \
-        create_supervised_evaluator, Events
-    import torchani
-    import torchani.data
-
-    path = os.path.dirname(os.path.realpath(__file__))
-    path = os.path.join(path, '../dataset/ani_gdb_s01.h5')
-    chunksize = 4
-    threshold = 1e-5
-
-    class TestIgnite(unittest.TestCase):
-
-        def testIgnite(self):
-            aev_computer = torchani.AEVComputer()
-            prepare = torchani.PrepareInput(aev_computer.species)
-            nnp = torchani.models.NeuroChemNNP(aev_computer.species)
-            shift_energy = torchani.EnergyShifter(aev_computer.species)
-            ds = torchani.data.ANIDataset(
-                path, chunksize,
-                transform=[shift_energy.subtract_from_dataset])
-            ds = torch.utils.data.Subset(ds, [0])
-            loader = torchani.data.dataloader(ds, 1)
-
-            class Flatten(torch.nn.Module):
-                def forward(self, x):
-                    return x[0], x[1].flatten()
-
-            model = torch.nn.Sequential(prepare, aev_computer, nnp, Flatten())
-            container = torchani.ignite.Container({'energies': model})
-            optimizer = torch.optim.Adam(container.parameters())
-            loss = torchani.ignite.TransformedLoss(
-                torchani.ignite.MSELoss('energies'),
-                lambda x: torch.exp(x) - 1)
-            trainer = create_supervised_trainer(
-                container, optimizer, loss)
-            evaluator = create_supervised_evaluator(container, metrics={
-                'RMSE': torchani.ignite.RMSEMetric('energies')
-            })
-
-            @trainer.on(Events.COMPLETED)
-            def completes(trainer):
-                evaluator.run(loader)
-                metrics = evaluator.state.metrics
-                self.assertLess(metrics['RMSE'], threshold)
-                self.assertLess(trainer.state.output, threshold)
-
-            trainer.run(loader, max_epochs=1000)
+import os
+import unittest
+import torch
+from ignite.engine import create_supervised_trainer, \
+    create_supervised_evaluator, Events
+import torchani
+import torchani.training
+
+path = os.path.dirname(os.path.realpath(__file__))
+path = os.path.join(path, '../dataset/ani_gdb_s01.h5')
+batchsize = 4
+threshold = 1e-5
+
+
+class TestIgnite(unittest.TestCase):
+
+    def testIgnite(self):
+        aev_computer = torchani.AEVComputer()
+        nnp = torchani.models.NeuroChemNNP(aev_computer.species)
+        shift_energy = torchani.EnergyShifter(aev_computer.species)
+        ds = torchani.training.BatchedANIDataset(
+            path, aev_computer.species, batchsize,
+            transform=[shift_energy.subtract_from_dataset])
+        ds = torch.utils.data.Subset(ds, [0])
+
+        class Flatten(torch.nn.Module):
+            def forward(self, x):
+                return x[0], x[1].flatten()
+
+        model = torch.nn.Sequential(aev_computer, nnp, Flatten())
+        container = torchani.training.Container({'energies': model})
+        optimizer = torch.optim.Adam(container.parameters())
+        loss = torchani.training.TransformedLoss(
+            torchani.training.MSELoss('energies'),
+            lambda x: torch.exp(x) - 1)
+        trainer = create_supervised_trainer(
+            container, optimizer, loss)
+        evaluator = create_supervised_evaluator(container, metrics={
+            'RMSE': torchani.training.RMSEMetric('energies')
+        })
+
+        @trainer.on(Events.COMPLETED)
+        def completes(trainer):
+            evaluator.run(ds)
+            metrics = evaluator.state.metrics
+            self.assertLess(metrics['RMSE'], threshold)
+            self.assertLess(trainer.state.output, threshold)
+
+        trainer.run(ds, max_epochs=1000)

 if __name__ == '__main__':
     unittest.main()
...
@@ -8,7 +8,7 @@ import torchani
 import pickle
 from torchani import buildin_const_file, buildin_sae_file, \
     buildin_network_dir
-import torchani.pyanitools
+import torchani.training.pyanitools

 path = os.path.dirname(os.path.realpath(__file__))
 conv_au_ev = 27.21138505
@@ -58,19 +58,24 @@ class NeuroChem (torchani.aev.AEVComputer):
         energies, forces

+aev = torchani.AEVComputer()
 ncaev = NeuroChem().to(torch.device('cpu'))

 mol_count = 0
+species_indices = {aev.species[i]: i for i in range(len(aev.species))}
 for i in [1, 2, 3, 4]:
     data_file = os.path.join(
         path, '../dataset/ani_gdb_s0{}.h5'.format(i))
-    adl = torchani.pyanitools.anidataloader(data_file)
+    adl = torchani.training.pyanitools.anidataloader(data_file)
     for data in adl:
         coordinates = data['coordinates'][:10, :]
         coordinates = torch.from_numpy(coordinates).type(ncaev.EtaR.dtype)
-        species = data['species']
+        species = torch.tensor([species_indices[i] for i in data['species']],
+                               dtype=torch.long, device=torch.device('cpu')) \
+            .expand(10, -1)
         smiles = ''.join(data['smiles'])
-        radial, angular, energies, forces = ncaev(coordinates, species)
+        radial, angular, energies, forces = ncaev(coordinates, data['species'])
         pickleobj = (coordinates, species, radial, angular, energies, forces)
         dumpfile = os.path.join(
             path, '../tests/test_data/{}'.format(mol_count))
...
 from .energyshifter import EnergyShifter
 from . import models
-from . import data
-from . import ignite
+from . import training
 from . import padding
-from .aev import AEVComputer, PrepareInput
+from .aev import AEVComputer
 from .env import buildin_const_file, buildin_sae_file, buildin_network_dir, \
     buildin_model_prefix, buildin_ensemble

 __all__ = ['PrepareInput', 'AEVComputer', 'EnergyShifter',
-           'models', 'data', 'padding', 'ignite',
+           'models', 'training', 'padding', 'ignite',
            'buildin_const_file', 'buildin_sae_file', 'buildin_network_dir',
            'buildin_model_prefix', 'buildin_ensemble']
@@ -107,40 +107,6 @@ class AEVComputerBase(BenchmarkedModule):
         raise NotImplementedError('subclass must override this method')


-class PrepareInput(torch.nn.Module):
-
-    def __init__(self, species):
-        super(PrepareInput, self).__init__()
-        self.species = species
-
-    def species_to_tensor(self, species, device):
-        """Convert species list into a long tensor.
-
-        Parameters
-        ----------
-        species : list
-            List of string for the species of each atoms.
-        device : torch.device
-            The device to store tensor
-
-        Returns
-        -------
-        torch.Tensor
-            Long tensor for the species, where a value k means the species is
-            the same as self.species[k].
-        """
-        indices = {self.species[i]: i for i in range(len(self.species))}
-        values = [indices[i] for i in species]
-        return torch.tensor(values, dtype=torch.long, device=device)
-
-    def forward(self, species_coordinates):
-        species, coordinates = species_coordinates
-        conformations = coordinates.shape[0]
-        species = self.species_to_tensor(species, coordinates.device)
-        species = species.expand(conformations, -1)
-        return species, coordinates
-
-
 def _cutoff_cosine(distances, cutoff):
     """Compute the elementwise cutoff cosine function
@@ -319,8 +285,11 @@ class AEVComputer(AEVComputerBase):
         """Shape (conformations, atoms, atoms) storing Rij distances"""
         padding_mask = (species == -1).unsqueeze(1)
-        distances = torch.where(padding_mask, torch.tensor(math.inf),
-                                distances)
+        distances = torch.where(
+            padding_mask,
+            torch.tensor(math.inf, dtype=self.EtaR.dtype,
+                         device=self.EtaR.device),
+            distances)
         distances, indices = distances.sort(-1)
...
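For context on the `torch.where` change above: both branches must live on the same device with the same dtype, so a bare `torch.tensor(math.inf)` (a float32 CPU scalar) fails once the model runs in double precision or on a GPU. A small illustration of the fixed pattern:

```python
import math

import torch

distances = torch.rand(2, 3, 3, dtype=torch.float64)
padding_mask = torch.zeros(2, 1, 3, dtype=torch.bool)
# build the infinity scalar with a matching dtype and device
inf = torch.tensor(math.inf, dtype=distances.dtype, device=distances.device)
masked = torch.where(padding_mask, inf, distances)
```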
@@ -22,21 +22,18 @@ class EnergyShifter(torch.nn.Module):
             torch.tensor(self_energies_tensor,
                          dtype=torch.double))

-    def sae_from_list(self, species):
-        energies = [self.self_energies[i] for i in species]
-        return sum(energies)
-
-    def sae_from_tensor(self, species):
+    def sae(self, species):
         self_energies = self.self_energies_tensor[species]
         self_energies[species == -1] = 0
         return self_energies.sum(dim=1)

-    def subtract_from_dataset(self, data):
-        sae = self.sae_from_list(data['species'])
-        data['energies'] -= sae
-        return data
+    def subtract_from_dataset(self, species, coordinates, properties):
+        dtype = properties['energies'].dtype
+        device = properties['energies'].device
+        properties['energies'] -= self.sae(species).to(dtype).to(device)
+        return species, coordinates, properties

     def forward(self, species_energies):
         species, energies = species_energies
-        sae = self.sae_from_tensor(species).to(energies.dtype)
+        sae = self.sae(species).to(energies.dtype).to(energies.device)
         return species, energies + sae
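A worked illustration of the merged `sae` above. The self energies are made-up numbers; note how a padding index of -1 first gathers the last row of the table and is then zeroed out by the mask:

```python
import torch

# made-up self energies for species 0..3 (say H, C, N, O)
self_energies_tensor = torch.tensor([-0.6, -38.1, -54.7, -75.2],
                                    dtype=torch.double)
species = torch.tensor([[1, 0, 0, 0, -1],   # CH3 plus one padding atom
                        [2, 0, 0, 1, 0]])   # a five-atom molecule
self_energies = self_energies_tensor[species]  # -1 gathers the last entry...
self_energies[species == -1] = 0               # ...which the mask zeroes
print(self_energies.sum(dim=1))                # per-conformation SAE
```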
@@ -27,3 +27,10 @@ def present_species(species):
     if present_species[0].item() == -1:
         present_species = present_species[1:]
     return present_species
+
+
+def strip_redundant_padding(species, coordinates):
+    non_padding = (species >= 0).any(dim=0)
+    species = species.masked_select(non_padding, dim=1)
+    coordinates = coordinates.masked_select(non_padding, dim=1)
+    return species, coordinates
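One caveat on the new helper: `torch.Tensor.masked_select` accepts no `dim` argument (it always returns a flattened 1-D tensor), so `strip_redundant_padding` as committed would raise a `TypeError` when called. A working equivalent of what the function appears to intend (my reading, not the upstream fix) keeps only the atom columns that contain at least one real atom:

```python
import torch


def strip_redundant_padding_fixed(species, coordinates):
    # columns where any conformation has a real (non-padding) atom
    non_padding = (species >= 0).any(dim=0)
    return species[:, non_padding], coordinates[:, non_padding]
```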
 from .container import Container
+from .data import BatchedANIDataset, load_or_create
 from .loss_metrics import DictLoss, DictMetric, MSELoss, RMSEMetric, \
     TransformedLoss
+from . import pyanitools

-__all__ = ['Container', 'DictLoss', 'DictMetric', 'MSELoss', 'RMSEMetric',
-           'TransformedLoss']
+__all__ = ['Container', 'BatchedANIDataset', 'load_or_create', 'DictLoss',
+           'DictMetric', 'MSELoss', 'RMSEMetric', 'TransformedLoss',
+           'pyanitools']
 import torch
-from ..data import collate


 class Container(torch.nn.Module):
@@ -10,13 +9,13 @@ class Container(torch.nn.Module):
         for i in models:
             setattr(self, 'model_' + i, models[i])

-    def forward(self, batch):
-        all_results = []
-        for i in zip(batch['species'], batch['coordinates']):
-            results = {}
-            for k in self.keys:
-                model = getattr(self, 'model_' + k)
-                _, results[k] = model(i)
-            all_results.append(results)
-        batch.update(collate(all_results))
-        return batch
+    def forward(self, species_coordinates):
+        species, coordinates = species_coordinates
+        results = {
+            'species': species,
+            'coordinates': coordinates,
+        }
+        for k in self.keys:
+            model = getattr(self, 'model_' + k)
+            _, results[k] = model((species, coordinates))
+        return results
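How the reworked `Container` is meant to be driven (a sketch assuming the torchani build from this commit, with the built-in NeuroChem networks available): `forward` now takes one padded `(species, coordinates)` batch instead of a dict of chunk lists, and returns a dict with one entry per wrapped model, which is what `DictLoss`/`DictMetric` compare against the target `properties` dict.

```python
import torch
import torchani

aev_computer = torchani.AEVComputer()
nnp = torchani.models.NeuroChemNNP(aev_computer.species)
container = torchani.training.Container(
    {'energies': torch.nn.Sequential(aev_computer, nnp)})

species = torch.tensor([[0, 0, -1]])   # one conformation, one padding atom
coordinates = torch.rand(1, 3, 3)
out = container((species, coordinates))
# out has keys 'species', 'coordinates', and 'energies'
```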
-from torch.utils.data import Dataset, DataLoader
+from torch.utils.data import Dataset
 from os.path import join, isfile, isdir
 import os
 from .pyanitools import anidataloader
 import torch
 import torch.utils.data as data
 import pickle
-import collections.abc
+from .. import padding


-class ANIDataset(Dataset):
+class BatchedANIDataset(Dataset):

-    def __init__(self, path, chunk_size, shuffle=True, properties=['energies'],
-                 transform=(), dtype=torch.get_default_dtype(),
-                 device=torch.device('cpu')):
-        super(ANIDataset, self).__init__()
+    def __init__(self, path, species, batch_size, shuffle=True,
+                 properties=['energies'], transform=(),
+                 dtype=torch.get_default_dtype(), device=torch.device('cpu')):
+        super(BatchedANIDataset, self).__init__()
         self.path = path
-        self.chunks_size = chunk_size
+        self.species = species
+        self.species_indices = {
+            self.species[i]: i for i in range(len(self.species))}
+        self.batch_size = batch_size
         self.shuffle = shuffle
         self.properties = properties
         self.dtype = dtype
@@ -33,51 +36,68 @@ class ANIDataset(Dataset):
         else:
             raise ValueError('Bad path')

-        # generate chunks
-        chunks = []
+        # load full dataset
+        species_coordinates = []
+        properties = {k: [] for k in self.properties}
         for f in files:
             for m in anidataloader(f):
-                full = {
-                    'coordinates': torch.from_numpy(m['coordinates'])
-                    .type(dtype).to(device)
-                }
-                conformations = full['coordinates'].shape[0]
-                for i in properties:
-                    full[i] = torch.from_numpy(m[i]).type(dtype).to(device)
                 species = m['species']
-                if shuffle:
-                    indices = torch.randperm(conformations, device=device)
-                else:
-                    indices = torch.arange(conformations, dtype=torch.int64,
-                                           device=device)
-                num_chunks = (conformations + chunk_size - 1) // chunk_size
-                for i in range(num_chunks):
-                    chunk_start = i * chunk_size
-                    chunk_end = min(chunk_start + chunk_size, conformations)
-                    chunk_indices = indices[chunk_start:chunk_end]
-                    chunk = {}
-                    for j in full:
-                        chunk[j] = full[j].index_select(0, chunk_indices)
-                    chunk['species'] = species
-                    for t in transform:
-                        chunk = t(chunk)
-                    chunks.append(chunk)
-        self.chunks = chunks
+                indices = [self.species_indices[i] for i in species]
+                species = torch.tensor(indices, dtype=torch.long,
+                                       device=device)
+                coordinates = torch.from_numpy(m['coordinates']) \
+                    .type(dtype).to(device)
+                species_coordinates.append((species, coordinates))
+                for i in properties:
+                    properties[i].append(torch.from_numpy(m[i])
+                                         .type(dtype).to(device))
+        species, coordinates = padding.pad_and_batch(species_coordinates)
+        for i in properties:
+            properties[i] = torch.cat(properties[i])
+
+        # shuffle if required
+        conformations = coordinates.shape[0]
+        if shuffle:
+            indices = torch.randperm(conformations, device=device)
+            species = species.index_select(0, indices)
+            coordinates = coordinates.index_select(0, indices)
+            for i in properties:
+                properties[i] = properties[i].index_select(0, indices)
+
+        # do transformations on data
+        for t in transform:
+            species, coordinates, properties = t(species, coordinates,
+                                                 properties)
+
+        # split into minibatches, and strip redundant padding
+        batches = []
+        num_batches = (conformations + batch_size - 1) // batch_size
+        for i in range(num_batches):
+            start = i * batch_size
+            end = min((i + 1) * batch_size, conformations)
+            species_batch = species[start:end, ...]
+            coordinates_batch = coordinates[start:end, ...]
+            properties_batch = {
+                k: properties[k][start:end, ...] for k in properties
+            }
+            batches.append(((species_batch, coordinates_batch),
+                            properties_batch))
+        self.batches = batches

     def __getitem__(self, idx):
-        chunk = self.chunks[idx]
-        input_chunk = {k: chunk[k] for k in ('coordinates', 'species')}
-        return input_chunk, chunk
+        return self.batches[idx]

     def __len__(self):
-        return len(self.chunks)
+        return len(self.batches)


-def load_or_create(checkpoint, dataset_path, chunk_size, *args, **kwargs):
+def load_or_create(checkpoint, batch_size, species, dataset_path,
+                   *args, **kwargs):
     """Generate a 80-10-10 split of the dataset, and checkpoint
     the resulting dataset"""
     if not os.path.isfile(checkpoint):
-        full_dataset = ANIDataset(dataset_path, chunk_size, *args, **kwargs)
+        full_dataset = BatchedANIDataset(dataset_path, species, batch_size,
+                                         *args, **kwargs)
         training_size = int(len(full_dataset) * 0.8)
         validation_size = int(len(full_dataset) * 0.1)
         testing_size = len(full_dataset) - training_size - validation_size
@@ -90,23 +110,3 @@ def load_or_create(checkpoint, dataset_path, chunk_size, *args, **kwargs):
     with open(checkpoint, 'rb') as f:
         training, validation, testing = pickle.load(f)
     return training, validation, testing
-
-
-def collate(batch):
-    no_collate = ['coordinates', 'species']
-    if isinstance(batch[0], torch.Tensor):
-        return torch.cat(batch)
-    elif isinstance(batch[0], collections.abc.Mapping):
-        return {key: ((lambda x: x) if key in no_collate else collate)
-                ([d[key] for d in batch])
-                for key in batch[0]}
-    elif isinstance(batch[0], collections.abc.Sequence):
-        transposed = zip(*batch)
-        return [collate(samples) for samples in transposed]
-    else:
-        raise ValueError('Unexpected element type: {}'.format(type(batch[0])))
-
-
-def dataloader(dataset, batch_chunks, shuffle=True, **kwargs):
-    return DataLoader(dataset, batch_chunks, shuffle,
-                      collate_fn=collate, **kwargs)
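Because `BatchedANIDataset` precomputes padded minibatches, each indexed element is already an `((species, coordinates), properties)` pair, which is why the `collate`/`dataloader` helpers above could be deleted: the dataset is consumed directly, as in the ignite test. A usage sketch (the dataset path mirrors the one used in the tests and is an assumption about your layout):

```python
import torchani

aev_computer = torchani.AEVComputer()
ds = torchani.training.BatchedANIDataset(
    '../dataset/ani_gdb_s01.h5', aev_computer.species, 256)
for (species, coordinates), properties in ds:
    print(species.shape, coordinates.shape, properties['energies'].shape)
```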
@@ -17,17 +17,10 @@ class DictLoss(_Loss):

 class _PerAtomDictLoss(DictLoss):

-    @staticmethod
-    def num_atoms(input):
-        ret = []
-        for s, c in zip(input['species'], input['coordinates']):
-            ret.append(torch.full((c.shape[0],), len(s),
-                                  dtype=c.dtype, device=c.device))
-        return torch.cat(ret)
-
     def forward(self, input, other):
         loss = self.loss(input[self.key], other[self.key])
-        loss /= self.num_atoms(input)
+        num_atoms = (input['species'] >= 0).sum(dim=1)
+        loss /= num_atoms.to(loss.dtype).to(loss.device)
         n = loss.numel()
         return loss.sum() / n
...
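The replacement logic above recovers the per-conformation atom count directly from the padded species tensor, so the loss no longer needs the `coordinates` entry or a Python loop:

```python
import torch

species = torch.tensor([[0, 1, 1, -1, -1],
                        [2, 3, 0, 0, 1]])
# padding atoms are -1, so non-negative entries count the real atoms
num_atoms = (species >= 0).sum(dim=1)   # tensor([3, 5])
```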