Unverified commit b3744935, authored by Gao, Xiang, committed by GitHub

working ignite and dataloader with simple examples (#21)

parent ea718be0
 version: '1.0'
 steps:
-  build-torchani:
+  BuildTorchANI:
     type: build
     description: Build TorchANI
-    image-name: torchani
+    image_name: torchani
     dockerfile: Dockerfile
     tag: latest
-  unit-tests:
-    image: '${{build-torchani}}'
+  CodeStyle:
+    image: '${{BuildTorchANI}}'
     commands:
       - flake8
+  UnitTests:
+    image: '${{BuildTorchANI}}'
+    commands:
       - python setup.py test
       # - python2 setup.py test
+  Examples:
+    image: '${{BuildTorchANI}}'
+    commands:
+      - python examples/nnp_training.py ./dataset/ani_gdb_s01.h5
+      - python examples/training-benchmark.py ./dataset/ani_gdb_s01.h5
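The pipeline now builds the image once in `BuildTorchANI` and runs the style check, the unit tests, and the two example scripts as separate steps against that image, so a flake8 failure is reported independently of a test or example failure.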
import torch
import torchani
import torchani.data
import math
import timeit
import sys
import pickle
from tensorboardX import SummaryWriter
from tqdm import tqdm
from common import get_or_create_model, Averager, evaluate
import json

chunk_size = 256
batch_chunks = 1024 // chunk_size

with open('data/dataset.dat', 'rb') as f:
    training, validation, testing = pickle.load(f)

training_sampler = torchani.data.BatchSampler(
    training, chunk_size, batch_chunks)
validation_sampler = torchani.data.BatchSampler(
    validation, chunk_size, batch_chunks)
testing_sampler = torchani.data.BatchSampler(
    testing, chunk_size, batch_chunks)
training_dataloader = torch.utils.data.DataLoader(
    training, batch_sampler=training_sampler,
    collate_fn=torchani.data.collate)
validation_dataloader = torch.utils.data.DataLoader(
    validation, batch_sampler=validation_sampler,
    collate_fn=torchani.data.collate)
testing_dataloader = torch.utils.data.DataLoader(
    testing, batch_sampler=testing_sampler,
    collate_fn=torchani.data.collate)
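# Sampling note (inferred from usage in this script, not from torchani docs):
# each chunk holds up to chunk_size conformations of a single molecule, and a
# batch is batch_chunks such chunks, i.e. about chunk_size * batch_chunks
# (= 1024 here) conformations per optimizer step. torchani.data.collate turns
# the sampled indices into the {molecule_id: (coordinates, energies)} dicts
# that the loops below unpack.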
writer = SummaryWriter('runs/adam-{}'.format(sys.argv[1]))
checkpoint = 'checkpoint.pt'
model = get_or_create_model(checkpoint)
optimizer = torch.optim.Adam(model.parameters(), **json.loads(sys.argv[1]))

step = 0
epoch = 0


def subset_rmse(subset_dataloader):
    a = Averager()
    for batch in subset_dataloader:
        for molecule_id in batch:
            _species = subset_dataloader.dataset.species[molecule_id]
            coordinates, energies = batch[molecule_id]
            coordinates = coordinates.to(model.aev_computer.device)
            energies = energies.to(model.aev_computer.device)
            count, squared_error = evaluate(
                model, coordinates, energies, _species)
            squared_error = squared_error.item()
            a.add(count, squared_error)
    mse = a.avg()
    rmse = math.sqrt(mse) * 627.509
    return rmse
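# 627.509 is the Hartree-to-kcal/mol conversion factor, so RMSEs are reported
# in kcal/mol, the conventional unit for ANI-style potentials.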
def optimize_step(a):
    mse = a.avg()
    rmse = math.sqrt(mse.item()) * 627.509
    writer.add_scalar('training_rmse_vs_step', rmse, step)
    loss = mse if epoch < 10 else 0.5 * torch.exp(2 * mse)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
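# The loss is plain MSE for the first 10 epochs, then switches to
# 0.5 * exp(2 * mse), which penalizes large errors far more steeply; this
# resembles the exponential cost function used in the original ANI training
# recipe.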
best_validation_rmse = math.inf
best_epoch = 0
start = timeit.default_timer()
while True:
    for batch in tqdm(training_dataloader,
                      desc='epoch {}'.format(epoch),
                      total=len(training_sampler)):
        a = Averager()
        for molecule_id in batch:
            _species = training.species[molecule_id]
            coordinates, energies = batch[molecule_id]
            coordinates = coordinates.to(model.aev_computer.device)
            energies = energies.to(model.aev_computer.device)
            count, squared_error = evaluate(
                model, coordinates, energies, _species)
            a.add(count, squared_error / len(_species))
        optimize_step(a)
        step += 1
    validation_rmse = subset_rmse(validation_dataloader)
    elapsed = round(timeit.default_timer() - start, 2)
    print('Epoch:', epoch, 'time:', elapsed,
          'validation rmse:', validation_rmse)
    writer.add_scalar('validation_rmse_vs_epoch', validation_rmse, epoch)
    writer.add_scalar('epoch_vs_step', epoch, step)
    writer.add_scalar('time_vs_epoch', elapsed, epoch)
    if validation_rmse < best_validation_rmse:
        best_validation_rmse = validation_rmse
        best_epoch = epoch
        writer.add_scalar('best_validation_rmse_vs_epoch',
                          best_validation_rmse, best_epoch)
    elif epoch - best_epoch > 1000:
        print('Stop at best validation rmse:', best_validation_rmse)
        break
    epoch += 1

testing_rmse = subset_rmse(testing_dataloader)
print('Test rmse:', testing_rmse)
import torch
import torchani
import torchani.data
import math
import timeit
import pickle
from tensorboardX import SummaryWriter
from tqdm import tqdm
from common import get_or_create_model, Averager, evaluate

chunk_size = 256
batch_chunks = 1024 // chunk_size

with open('data/dataset.dat', 'rb') as f:
    training, validation, testing = pickle.load(f)

training_sampler = torchani.data.BatchSampler(
    training, chunk_size, batch_chunks)
validation_sampler = torchani.data.BatchSampler(
    validation, chunk_size, batch_chunks)
testing_sampler = torchani.data.BatchSampler(
    testing, chunk_size, batch_chunks)
training_dataloader = torch.utils.data.DataLoader(
    training, batch_sampler=training_sampler,
    collate_fn=torchani.data.collate)
validation_dataloader = torch.utils.data.DataLoader(
    validation, batch_sampler=validation_sampler,
    collate_fn=torchani.data.collate)
testing_dataloader = torch.utils.data.DataLoader(
    testing, batch_sampler=testing_sampler,
    collate_fn=torchani.data.collate)

writer = SummaryWriter()
checkpoint = 'checkpoint.pt'
model = get_or_create_model(checkpoint)
optimizer = torch.optim.Adam(model.parameters(), amsgrad=True)

step = 0
epoch = 0


def subset_rmse(subset_dataloader):
    a = Averager()
    for batch in subset_dataloader:
        for molecule_id in batch:
            _species = subset_dataloader.dataset.species[molecule_id]
            coordinates, energies = batch[molecule_id]
            coordinates = coordinates.to(model.aev_computer.device)
            energies = energies.to(model.aev_computer.device)
            count, squared_error = evaluate(
                model, coordinates, energies, _species)
            squared_error = squared_error.item()
            a.add(count, squared_error)
    mse = a.avg()
    rmse = math.sqrt(mse) * 627.509
    return rmse
def optimize_step(a):
    mse = a.avg()
    rmse = math.sqrt(mse.item()) * 627.509
    writer.add_scalar('training_rmse_vs_step', rmse, step)
    loss = mse if epoch < 10 else 0.5 * torch.exp(2 * mse)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


best_validation_rmse = math.inf
best_epoch = 0
start = timeit.default_timer()
while True:
    for batch in tqdm(training_dataloader, desc='epoch {}'.format(epoch),
                      total=len(training_sampler)):
        a = Averager()
        for molecule_id in batch:
            _species = training.species[molecule_id]
            coordinates, energies = batch[molecule_id]
            coordinates = coordinates.to(model.aev_computer.device)
            energies = energies.to(model.aev_computer.device)
            count, squared_error = evaluate(
                model, coordinates, energies, _species)
            a.add(count, squared_error / len(_species))
        optimize_step(a)
        step += 1
    validation_rmse = subset_rmse(validation_dataloader)
    elapsed = round(timeit.default_timer() - start, 2)
    print('Epoch:', epoch, 'time:', elapsed,
          'validation rmse:', validation_rmse)
    writer.add_scalar('validation_rmse_vs_epoch', validation_rmse, epoch)
    writer.add_scalar('epoch_vs_step', epoch, step)
    writer.add_scalar('time_vs_epoch', elapsed, epoch)
    if validation_rmse < best_validation_rmse:
        best_validation_rmse = validation_rmse
        best_epoch = epoch
        writer.add_scalar('best_validation_rmse_vs_epoch',
                          best_validation_rmse, best_epoch)
        torch.save(model.state_dict(), checkpoint)
    elif epoch - best_epoch > 1000:
        print('Stop at best validation rmse:', best_validation_rmse)
        break
    epoch += 1

testing_rmse = subset_rmse(testing_dataloader)
print('Test rmse:', testing_rmse)
import pickle
import torch

hyperparams = [  # (chunk size, batch chunks)
    # (64, 4),
    (64, 8),
    (64, 16),
    (64, 32),
    (128, 2),
    (128, 4),
    (128, 8),
    (128, 16),
    (256, 1),
    (256, 2),
    (256, 4),
    (256, 8),
    (512, 1),
    (512, 2),
    (512, 4),
    (1024, 1),
    (1024, 2),
    (2048, 1),
]

for chunk_size, batch_chunks in hyperparams:
    with open('data/avg-{}-{}.dat'.format(chunk_size, batch_chunks),
              'rb') as f:
        ag, agsqr = pickle.load(f)
    variance = torch.sum(agsqr) - torch.sum(ag**2)
    stddev = torch.sqrt(variance).item()
    print(chunk_size, batch_chunks, stddev)
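# With ag = E[g] and agsqr = E[g**2] averaged over batches, the per-component
# gradient variance is agsqr - ag**2; summing over components and taking the
# square root reduces it to a single scalar measure of gradient noise for each
# (chunk_size, batch_chunks) setting.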
import sys
import torch
import torchani
import configs
import torchani.data
from tqdm import tqdm
import pickle
from common import get_or_create_model, Averager, evaluate

device = configs.device
if len(sys.argv) >= 2:
    device = torch.device(sys.argv[1])
ds = torchani.data.load_dataset(configs.data_path)
model = get_or_create_model('/tmp/model.pt', device=device)
# just to conveniently zero grads
optimizer = torch.optim.Adam(model.parameters())


def grad_or_zero(parameter):
    if parameter.grad is not None:
        return parameter.grad.reshape(-1)
    else:
        return torch.zeros_like(parameter.reshape(-1))
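# Parameters that received no gradient in a batch (presumably the networks for
# species absent from the sampled molecules) have .grad == None; substituting
# zeros keeps the concatenated gradient vector the same length for every batch.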
def batch_gradient(batch):
    a = Averager()
    for molecule_id in batch:
        _species = ds.species[molecule_id]
        coordinates, energies = batch[molecule_id]
        coordinates = coordinates.to(model.aev_computer.device)
        energies = energies.to(model.aev_computer.device)
        a.add(*evaluate(model, coordinates, energies, _species))
    mse = a.avg()
    optimizer.zero_grad()
    mse.backward()
    grads = [grad_or_zero(p) for p in model.parameters()]
    grads = torch.cat(grads)
    return grads
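# batch_gradient returns one flat vector per batch; the variance analysis
# script above averages these vectors and their squares to estimate the
# gradient noise.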
def compute(chunk_size, batch_chunks):
    sampler = torchani.data.BatchSampler(ds, chunk_size, batch_chunks)
    dataloader = torch.utils.data.DataLoader(
        ds, batch_sampler=sampler, collate_fn=torchani.data.collate)
    model_file = 'data/model.pt'
    model.load_state_dict(torch.load(
        model_file, map_location=lambda storage, loc: storage))
    ag = Averager()  # avg(grad)
    agsqr = Averager()  # avg(grad^2)
    for batch in tqdm(dataloader, total=len(sampler)):
        g = batch_gradient(batch)
        ag.add(1, g)
        agsqr.add(1, g**2)
    ag = ag.avg()
    agsqr = agsqr.avg()
    filename = 'data/avg-{}-{}.dat'.format(chunk_size, batch_chunks)
    with open(filename, 'wb') as f:
        pickle.dump((ag, agsqr), f)


chunk_size = int(sys.argv[2])
batch_chunks = int(sys.argv[3])
compute(chunk_size, batch_chunks)
# for chunk_size, batch_chunks in hyperparams:
#     compute(chunk_size, batch_chunks)
import torch
data_path = 'data/ANI-1x_complete.h5'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import torchani.data
import pickle
from configs import data_path

chunk_size = 64
dataset = torchani.data.load_dataset(data_path)
chunks = len(torchani.data.BatchSampler(dataset, chunk_size, 1))
print(chunks, 'chunks')
training_size = int(chunks * 0.8)
validation_size = int(chunks * 0.1)
testing_size = chunks - training_size - validation_size
training, validation, testing = torchani.data.random_split(
    dataset, [training_size, validation_size, testing_size], chunk_size)
with open('data/dataset.dat', 'wb') as f:
    pickle.dump((training, validation, testing), f)
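# An 80%/10%/10% split at chunk granularity: assigning the remainder to the
# test set guarantees the three sizes sum exactly to `chunks`, and splitting
# by chunks keeps each subset directly usable with BatchSampler.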
-# flake8: noqa
 import torch
+import os
 import torchani
 
 device = torch.device('cpu')
-const_file = '../torchani/resources/ani-1x_dft_x8ens/rHCNO-5.2R_16-3.5A_a4-8.params'
-sae_file = '../torchani/resources/ani-1x_dft_x8ens/sae_linfit.dat'
-network_dir = '../torchani/resources/ani-1x_dft_x8ens/train'
+path = os.path.dirname(os.path.realpath(__file__))
+const_file = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/rHCNO-5.2R_16-3.5A_a4-8.params')  # noqa: E501
+sae_file = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/sae_linfit.dat')  # noqa: E501
+network_dir = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/train')  # noqa: E501
 aev_computer = torchani.SortedAEV(const_file=const_file, device=device)
-nn = torchani.ModelOnAEV(aev_computer, derivative=True,
-                         from_nc=network_dir, ensemble=8)
+nn = torchani.models.NeuroChemNNP(aev_computer, derivative=True,
+                                  from_=network_dir, ensemble=8)
 shift_energy = torchani.EnergyShifter(sae_file)
 coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
@@ -17,7 +18,8 @@ coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
                              [-0.66518241, -0.84461308, 0.20759389],
                              [0.45554739, 0.54289633, 0.81170881],
                              [0.66091919, -0.16799635, -0.91037834]]],
-                            dtype=aev_computer.dtype, device=aev_computer.device)
+                            dtype=aev_computer.dtype,
+                            device=aev_computer.device)
 species = ['C', 'H', 'H', 'H', 'H']
 energy, derivative = nn(coordinates, species)
...
-import torchani
 import torch
+import torchani
 import os
-import configs
-
-
-class Averager:
-    def __init__(self):
-        self.count = 0
-        self.subtotal = 0
-
-    def add(self, count, subtotal):
-        self.count += count
-        self.subtotal += subtotal
-
-    def avg(self):
-        return self.subtotal / self.count
 
 
 def celu(x, alpha):
@@ -28,14 +13,10 @@ class AtomicNetwork(torch.nn.Module):
         super(AtomicNetwork, self).__init__()
         self.aev_computer = aev_computer
         self.output_length = 1
-        self.layer1 = torch.nn.Linear(384, 128).type(
-            aev_computer.dtype).to(aev_computer.device)
-        self.layer2 = torch.nn.Linear(128, 128).type(
-            aev_computer.dtype).to(aev_computer.device)
-        self.layer3 = torch.nn.Linear(128, 64).type(
-            aev_computer.dtype).to(aev_computer.device)
-        self.layer4 = torch.nn.Linear(64, 1).type(
-            aev_computer.dtype).to(aev_computer.device)
+        self.layer1 = torch.nn.Linear(384, 128)
+        self.layer2 = torch.nn.Linear(128, 128)
+        self.layer3 = torch.nn.Linear(128, 64)
+        self.layer4 = torch.nn.Linear(64, 1)
 
     def forward(self, aev):
         y = aev
@@ -49,9 +30,10 @@ class AtomicNetwork(torch.nn.Module):
         return y
 
 
-def get_or_create_model(filename, benchmark=False, device=configs.device):
+def get_or_create_model(filename, benchmark=False,
+                        device=torchani.default_device):
     aev_computer = torchani.SortedAEV(benchmark=benchmark, device=device)
-    model = torchani.ModelOnAEV(
+    model = torchani.models.CustomModel(
         aev_computer,
         reducer=torch.sum,
         benchmark=benchmark,
@@ -65,17 +47,4 @@ def get_or_create_model(filename, benchmark=False, device=configs.device):
         model.load_state_dict(torch.load(filename))
     else:
         torch.save(model.state_dict(), filename)
-    return model
+    return model.to(device)
-
-
-energy_shifter = torchani.EnergyShifter()
-loss = torch.nn.MSELoss(size_average=False)
-
-
-def evaluate(model, coordinates, energies, species):
-    count = coordinates.shape[0]
-    pred = model(coordinates, species).squeeze()
-    pred = energy_shifter.add_sae(pred, species)
-    squared_error = loss(pred, energies)
-    return count, squared_error
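Note: in the removed helpers, `evaluate` returned `(count, summed_squared_error)` because `torch.nn.MSELoss(size_average=False)` sums rather than averages; `Averager` could therefore accumulate exact totals across unevenly sized chunks, with `avg()` recovering the true mean squared error.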
import sys
import torch
import ignite
import torchani
import model
chunk_size = 256
batch_chunks = 4
dataset_path = sys.argv[1]
dataset_checkpoint = 'dataset-checkpoint.dat'
model_checkpoint = 'checkpoint.pt'
shift_energy = torchani.EnergyShifter()
training, validation, testing = torchani.data.load_or_create(
dataset_checkpoint, dataset_path, chunk_size,
transform=[shift_energy.dataset_subtract_sae])
training = torchani.data.dataloader(training, batch_chunks)
validation = torchani.data.dataloader(validation, batch_chunks)
nnp = model.get_or_create_model(model_checkpoint)
class Flatten(torch.nn.Module):
def __init__(self, model):
super(Flatten, self).__init__()
self.model = model
def forward(self, *input):
return self.model(*input).flatten()
batch_nnp = torchani.models.BatchModel(Flatten(nnp))
container = torchani.ignite.Container({'energies': batch_nnp})
optimizer = torch.optim.Adam(nnp.parameters())
trainer = ignite.engine.create_supervised_trainer(
container, optimizer, torchani.ignite.energy_mse_loss)
evaluator = ignite.engine.create_supervised_evaluator(container, metrics={
'RMSE': torchani.ignite.energy_rmse_metric
})
@trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
def log_training_loss(trainer):
print("Epoch[{}] Loss: {:.2f}".format(trainer.state.epoch,
trainer.state.output))
@trainer.on(ignite.engine.Events.EPOCH_COMPLETED)
def log_training_results(trainer):
evaluator.run(training)
metrics = evaluator.state.metrics
print("Training Results - Epoch: {} RMSE: {:.2f}"
.format(trainer.state.epoch, metrics['RMSE']))
@trainer.on(ignite.engine.Events.EPOCH_COMPLETED)
def log_validation_results(trainer):
evaluator.run(validation)
metrics = evaluator.state.metrics
print("Validation Results - Epoch: {} RMSE: {:.2f}"
.format(trainer.state.epoch, metrics['RMSE']))
trainer.run(training, max_epochs=10)
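# Wiring note (shapes inferred from this example, not from a documented API):
# Container adapts the batched model to ignite's generic supervised
# trainer/evaluator, and Flatten reshapes the model output into the 1-D
# energies tensor that energy_mse_loss and energy_rmse_metric appear to
# expect.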
+import sys
 import torch
+import ignite
 import torchani
-import torchani.data
-import tqdm
 import timeit
-import configs
-import functools
-from common import get_or_create_model, Averager, evaluate
+import model
 
-ds = torchani.data.load_dataset(configs.data_path)
-sampler = torchani.data.BatchSampler(ds, 256, 4)
-dataloader = torch.utils.data.DataLoader(
-    ds, batch_sampler=sampler,
-    collate_fn=torchani.data.collate, num_workers=20)
-model = get_or_create_model('/tmp/model.pt', True)
-optimizer = torch.optim.Adam(model.parameters(), amsgrad=True)
+chunk_size = 256
+batch_chunks = 4
+dataset_path = sys.argv[1]
+shift_energy = torchani.EnergyShifter()
+dataset = torchani.data.ANIDataset(
+    dataset_path, chunk_size,
+    transform=[shift_energy.dataset_subtract_sae])
+dataloader = torchani.data.dataloader(dataset, batch_chunks)
+nnp = model.get_or_create_model('/tmp/model.pt', True)
 
 
-def benchmark(timer, index):
-    def wrapper(fun):
-        @functools.wraps(fun)
-        def wrapped(*args, **kwargs):
-            start = timeit.default_timer()
-            ret = fun(*args, **kwargs)
-            end = timeit.default_timer()
-            timer[index] += end - start
-            return ret
-        return wrapped
-    return wrapper
+class Flatten(torch.nn.Module):
+    def __init__(self, model):
+        super(Flatten, self).__init__()
+        self.model = model
 
+    def forward(self, *input):
+        return self.model(*input).flatten()
 
-timer = {'backward': 0}
 
+batch_nnp = torchani.models.BatchModel(Flatten(nnp))
+container = torchani.ignite.Container({'energies': batch_nnp})
+optimizer = torch.optim.Adam(nnp.parameters())
 
-@benchmark(timer, 'backward')
-def optimize_step(a):
-    mse = a.avg()
-    optimizer.zero_grad()
-    mse.backward()
-    optimizer.step()
-
+trainer = ignite.engine.create_supervised_trainer(
+    container, optimizer, torchani.ignite.energy_mse_loss)
 
 start = timeit.default_timer()
-for batch in tqdm.tqdm(dataloader, total=len(sampler)):
-    a = Averager()
-    for molecule_id in batch:
-        _species = ds.species[molecule_id]
-        coordinates, energies = batch[molecule_id]
-        coordinates = coordinates.to(model.aev_computer.device)
-        energies = energies.to(model.aev_computer.device)
-        a.add(*evaluate(model, coordinates, energies, _species))
-    optimize_step(a)
+trainer.run(dataloader, max_epochs=1)
 elapsed = round(timeit.default_timer() - start, 2)
-print('Radial terms:', model.aev_computer.timers['radial terms'])
-print('Angular terms:', model.aev_computer.timers['angular terms'])
-print('Terms and indices:', model.aev_computer.timers['terms and indices'])
-print('Combinations:', model.aev_computer.timers['combinations'])
-print('Mask R:', model.aev_computer.timers['mask_r'])
-print('Mask A:', model.aev_computer.timers['mask_a'])
-print('Assemble:', model.aev_computer.timers['assemble'])
-print('Total AEV:', model.aev_computer.timers['total'])
-print('NN:', model.timers['nn'])
-print('Total Forward:', model.timers['forward'])
-print('Total Backward:', timer['backward'])
+print('Radial terms:', nnp.aev_computer.timers['radial terms'])
+print('Angular terms:', nnp.aev_computer.timers['angular terms'])
+print('Terms and indices:', nnp.aev_computer.timers['terms and indices'])
+print('Combinations:', nnp.aev_computer.timers['combinations'])
+print('Mask R:', nnp.aev_computer.timers['mask_r'])
+print('Mask A:', nnp.aev_computer.timers['mask_a'])
+print('Assemble:', nnp.aev_computer.timers['assemble'])
+print('Total AEV:', nnp.aev_computer.timers['total'])
+print('NN:', nnp.timers['nn'])
+print('Total Forward:', nnp.timers['forward'])
 print('Epoch time:', elapsed)
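Note: the rewritten benchmark drops the hand-rolled `functools` backward timer. With `benchmark=True`, the AEV computer and the model record their own timers, which the final prints read from `nnp.aev_computer.timers` and `nnp.timers`; backward-pass time is no longer reported separately.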
-from setuptools import setup
+from setuptools import setup, find_packages
 from sphinx.setup_command import BuildDoc
 
 cmdclass = {'build_sphinx': BuildDoc}
@@ -9,7 +9,7 @@ setup(name='torchani',
       author='Xiang Gao',
       author_email='qasdfgtyuiop@ufl.edu',
       license='MIT',
-      packages=['torchani'],
+      packages=find_packages(),
       include_package_data=True,
       install_requires=[
           'torch',
...
@@ -9,7 +9,7 @@ if sys.version_info.major >= 3:
 import itertools
 path = os.path.dirname(os.path.realpath(__file__))
-path = os.path.join(path, 'dataset')
+path = os.path.join(path, '../dataset')
 chunksize = 32
 batch_chunks = 32
 dtype = torch.float32
...
@@ -6,7 +6,7 @@ if sys.version_info.major >= 3:
 import torchani.data
 path = os.path.dirname(os.path.realpath(__file__))
-path = os.path.join(path, 'dataset')
+path = os.path.join(path, '../dataset')
 class TestDataset(unittest.TestCase):
...
@@ -5,7 +5,7 @@ import torch
 import torchani
 path = os.path.dirname(os.path.realpath(__file__))
-N = 97
+N = 10
 class TestEnsemble(unittest.TestCase):
...
@@ -10,7 +10,7 @@ if sys.version_info.major >= 3:
 import torchani.data
 path = os.path.dirname(os.path.realpath(__file__))
-path = os.path.join(path, 'dataset/ani_gdb_s01.h5')
+path = os.path.join(path, '../dataset/ani_gdb_s01.h5')
 chunksize = 32
 batch_chunks = 32
 dtype = torch.float32
...