Commit 8b9c3d1a authored by Gao, Xiang's avatar Gao, Xiang Committed by Farhad Ramezanghorbani
Browse files

Support 0 molecule subsets when loading dataset (#257)

* Support 0 molecule subsets when loading dataset

* fix

* fix
parent 909935a5
......@@ -48,17 +48,20 @@ def pad_atomic_properties(atomic_properties, padding_values=defaultdict(lambda:
max_atoms = max(x[anykey].shape[1] for x in atomic_properties)
padded = {k: [] for k in keys}
for p in atomic_properties:
num_molecules = max(v.shape[0] for v in p.values())
num_molecules = 1
for v in p.values():
assert num_molecules in {1, v.shape[0]}, 'Number of molecules in different atomic properties mismatch'
if v.shape[0] != 1:
num_molecules = v.shape[0]
for k, v in p.items():
shape = list(v.shape)
padatoms = max_atoms - shape[1]
shape[1] = padatoms
padding = v.new_full(shape, padding_values[k])
v = torch.cat([v, padding], dim=1)
if v.shape[0] < num_molecules:
shape = list(v.shape)
shape[0] = num_molecules
v = v.expand(*shape)
shape = list(v.shape)
shape[0] = num_molecules
v = v.expand(*shape)
padded[k].append(v)
return {k: torch.cat(v) for k, v in padded.items()}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment