Commit 987711c7 authored by Richard Xue's avatar Richard Xue Committed by Gao, Xiang
Browse files

fix new dataset reshape bug (#405)

parent abc8f7f8
......@@ -554,7 +554,7 @@ def collate_fn(data, chunk_threshold, properties_info):
padding_value=-1)
batch_coordinates = torch.nn.utils.rnn.pad_sequence(batch_coordinates,
batch_first=True,
padding_value=0)
padding_value=np.inf)
# sort - time: 0.7s
atoms = torch.sum(~(batch_species == -1), dim=-1, dtype=torch.int32)
......@@ -574,7 +574,7 @@ def collate_fn(data, chunk_threshold, properties_info):
# truncate redundant padding - time: 1.3s
chunks_batch_species = trunc_pad(list(chunks_batch_species), padding_value=-1)
chunks_batch_coordinates = trunc_pad(list(chunks_batch_coordinates))
chunks_batch_coordinates = trunc_pad(list(chunks_batch_coordinates), padding_value=np.inf)
for i, c in enumerate(chunks_batch_coordinates):
chunks_batch_coordinates[i] = c.reshape(c.shape[0], -1, 3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment