graph_build.py 3.03 KB
Newer Older
zcxzcx1's avatar
zcxzcx1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
from typing import List, Optional

from sevenn.logger import Logger
from sevenn.train.dataset import AtomGraphDataset
from sevenn.util import unique_filepath


def build_sevennet_graph_dataset(
    source: List[str],
    cutoff: float,
    num_cores: int,
    out: str,
    filename: str,
    metadata: Optional[dict] = None,
    **fmt_kwargs,
):
    """
    Build (or load) a SevenNetGraphDataset from ``source`` and log a summary.

    Args:
        source: input file paths handed to the dataset as ``files``.
        cutoff: cutoff passed through to the dataset constructor.
        num_cores: passed as ``process_num_cores``.
        out: dataset root directory (``root``).
        filename: name of the processed file (``processed_name``).
        metadata: optional key/value pairs that are only logged, between
            two bars; ``None`` means nothing extra is logged.
        **fmt_kwargs: forwarded verbatim to the dataset constructor.

    Returns:
        None. All results are reported through the Logger.
    """
    # Local import kept from the original; presumably it defers the heavy
    # dataset machinery until this function is actually used -- confirm.
    from sevenn.train.graph_dataset import SevenNetGraphDataset

    extra_info = {} if metadata is None else metadata
    logger = Logger()

    logger.timer_start('graph_build')
    graph_db = SevenNetGraphDataset(
        cutoff=cutoff,
        root=out,
        files=source,
        processed_name=filename,
        process_num_cores=num_cores,
        **fmt_kwargs,
    )
    logger.timer_end('graph_build', 'graph build time')
    saved_path = graph_db.processed_paths[0]
    logger.writeline(f'Graph saved: {saved_path}')

    # User-supplied metadata, framed by separator bars.
    logger.bar()
    for key in extra_info:
        logger.format_k_v(key, extra_info[key], write=True)
    logger.bar()

    # Dataset statistics followed by node / graph counts.
    logger.writeline('Distribution:')
    logger.statistic_write(graph_db.statistics)
    logger.format_k_v('# atoms (node)', graph_db.natoms, write=True)
    logger.format_k_v('# structures (graph)', len(graph_db), write=True)


def dataset_finalize(dataset, metadata, out):
    """
    Attach metadata to ``dataset``, save it to disk and return it.

    Deprecated.

    Args:
        dataset: dataset object; must provide ``get_natoms()``,
            ``get_species()``, a writable ``meta`` attribute and
            ``save(path)``.
        metadata: mapping of extra metadata, or ``None`` for none. It is
            copied, not mutated. The computed 'natoms' and 'species'
            entries override any same-named keys.
        out: output location. If it is an existing directory, the file
            'graph_built.sevenn_data' inside it is used. Otherwise the
            '.sevenn_data' suffix is appended when missing. The final path
            is passed through ``unique_filepath`` before saving.

    Returns:
        The same ``dataset`` object, with ``meta`` set.
    """
    if metadata is None:
        metadata = {}
    metadata = {
        **metadata,
        'natoms': dataset.get_natoms(),
        'species': dataset.get_species(),
    }
    dataset.meta = metadata

    if os.path.isdir(out):
        out = os.path.join(out, 'graph_built.sevenn_data')
    elif not out.endswith('.sevenn_data'):
        out = out + '.sevenn_data'
    # Presumably picks a non-colliding name if `out` already exists -- confirm
    # against sevenn.util.unique_filepath.
    out = unique_filepath(out)

    log = Logger()
    log.writeline('The metadata of the dataset is...')
    for k, v in metadata.items():
        log.format_k_v(k, v, write=True)
    dataset.save(out)
    log.writeline(f'dataset is saved to {out}')

    return dataset


def build_script(
    source: List[str],
    cutoff: float,
    num_cores: int,
    out: str,
    metadata: Optional[dict] = None,
    **fmt_kwargs,
):
    """
    Build an AtomGraphDataset from ``source`` files and save it.

    Deprecated.

    Each non-directory path in ``source`` is read with the
    'structure_list' reader if its basename contains 'structure_list', and
    with the 'ase' reader otherwise. Directories are skipped.

    Args:
        source: input paths.
        cutoff: cutoff used for the dataset and for each file conversion.
        num_cores: passed to ``file_to_dataset`` as ``cores``.
        out: output location, resolved by ``dataset_finalize``.
        metadata: optional extra metadata. The caller's dict is copied, so
            reader metadata merged in here does not leak back to the caller.
        **fmt_kwargs: forwarded to ``match_reader`` for every file.

    Returns:
        The finalized dataset returned by ``dataset_finalize``. Previously
        this returned None; returning the dataset is backward compatible.
    """
    from sevenn.train.dataload import file_to_dataset, match_reader

    # Copy instead of aliasing: reader metadata is merged into this dict.
    metadata = {} if metadata is None else dict(metadata)
    log = Logger()

    dataset = AtomGraphDataset({}, cutoff)
    common_args = {
        'cutoff': cutoff,
        'cores': num_cores,
        'label': 'graph_build',
    }
    log.timer_start('graph_build')
    for path in source:
        if os.path.isdir(path):
            # Only regular file inputs are handled; directories are ignored.
            continue
        log.writeline(f'Read: {path}')
        if 'structure_list' in os.path.basename(path):
            fmt = 'structure_list'
        else:
            fmt = 'ase'
        reader, rmeta = match_reader(fmt, **fmt_kwargs)
        # Later files override earlier reader metadata with the same keys.
        metadata.update(rmeta)
        dataset.augment(
            file_to_dataset(
                file=path,
                reader=reader,
                **common_args,
            )
        )
    log.timer_end('graph_build', 'graph build time')
    return dataset_finalize(dataset, metadata, out)