Unverified Commit 7c9402d0 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

preprocess data on CPU to save cuda memory (#64)

parent 50fb6cdb
......@@ -16,7 +16,7 @@ parser.add_argument('-d', '--device',
default=('cuda' if torch.cuda.is_available() else 'cpu'))
parser.add_argument('--batch_size',
help='Number of conformations of each batch',
default=1024, type=int)
default=256, type=int)
parser = parser.parse_args()
# set up benchmark
......
......@@ -23,6 +23,7 @@ class BatchedANIDataset(Dataset):
self.properties = properties
self.dtype = dtype
self.device = device
device = torch.device('cpu')
# get name of files storing data
files = []
......@@ -85,7 +86,13 @@ class BatchedANIDataset(Dataset):
self.batches = batches
def __getitem__(self, idx):
return self.batches[idx]
(species, coordinates), properties = self.batches[idx]
species = species.to(self.device)
coordinates = coordinates.to(self.device)
properties = {
k: properties[k].to(self.device) for k in properties
}
return (species, coordinates), properties
def __len__(self):
return len(self.batches)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment