collate.py 1.26 KB
Newer Older
zcxzcx1's avatar
zcxzcx1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from typing import Any, List, Optional, Sequence

from ase.atoms import Atoms
from torch_geometric.loader.dataloader import Collater

from sevenn.atom_graph_data import AtomGraphData

from .dataload import atoms_to_graph


class AtomsToGraphCollater(Collater):
    """Collate function that turns raw structures into one batched graph.

    Each element of an incoming batch is converted on the fly with
    ``atoms_to_graph`` (using this collater's ``cutoff``, ``transfer_info``
    and ``y_from_calc``), wrapped via ``AtomGraphData.from_numpy_dict``, and
    the resulting list of graphs is handed to torch_geometric's ``Collater``
    for the actual batching.
    """

    def __init__(
        self,
        dataset: Sequence[Atoms],
        cutoff: float,
        transfer_info: bool = False,
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
        y_from_calc: bool = True,
    ):
        """Store the graph-construction settings and set up the parent.

        Args:
            dataset: Source structures. Kept on ``self.dataset``; this class's
                ``__call__`` does not read it.
            cutoff: Cutoff forwarded to ``atoms_to_graph``.
            transfer_info: Forwarded to ``atoms_to_graph``.
            follow_batch: Forwarded to the parent ``Collater``.
            exclude_keys: Forwarded to the parent ``Collater``.
            y_from_calc: Forwarded to ``atoms_to_graph``.
        """
        # An empty list stands in for the parent's dataset argument
        # (NOTE(review): presumably to quiet a type mismatch, since the parent
        # expects graph data rather than ase Atoms). The real dataset is
        # assigned only after super().__init__ so that any dataset attribute
        # the parent sets is replaced rather than overriding ours.
        super().__init__([], follow_batch, exclude_keys)
        self.cutoff = cutoff
        self.transfer_info = transfer_info
        self.y_from_calc = y_from_calc
        self.dataset = dataset

    def __call__(self, batch: List[Any]) -> Any:
        """Convert every structure in ``batch`` to a graph, then collate.

        Args:
            batch: Structures, each passed individually to ``atoms_to_graph``.

        Returns:
            Whatever the parent ``Collater`` returns for the list of
            ``AtomGraphData`` objects built from ``batch`` (in input order).
        """
        graphs = [
            AtomGraphData.from_numpy_dict(
                atoms_to_graph(
                    atoms,
                    self.cutoff,
                    transfer_info=self.transfer_info,
                    y_from_calc=self.y_from_calc,
                )
            )
            for atoms in batch
        ]
        return super().__call__(graphs)