import itertools
import random
from collections import Counter
from typing import Callable, Dict, List, Optional, Union

import numpy as np
import torch
from ase.data import chemical_symbols
from sklearn.linear_model import Ridge

import sevenn._keys as KEY
import sevenn.util as util


class AtomGraphDataset:
    """
    Deprecated

    class representing dataset of AtomGraphData
    the dataset is handled as dict, {label: data}
    if given data is List, it stores data as {KEY_DEFAULT: data}

    cutoff is for metadata of the graphs not used for some calc
    Every data expected to have one unique cutoff
    No validity or check of the condition is done inside the object

    attribute:
        dataset (Dict[str, List]): key is data label(str), value is list of data
        user_labels (List[str]): list of user labels same as dataset.keys()
        meta (Dict, Optional): metadata of dataset
    for now, metadata 'might' have following keys:
        KEY.CUTOFF (float), KEY.CHEMICAL_SPECIES (Dict)
    """

    DATA_KEY_X = (
        KEY.NODE_FEATURE
    )  # atomic_number > one_hot_idx > one_hot_vector
    DATA_KEY_ENERGY = KEY.ENERGY
    DATA_KEY_FORCE = KEY.FORCE
    KEY_DEFAULT = KEY.LABEL_NONE

    def __init__(
        self,
        dataset: Union[Dict[str, List], List],
        cutoff: float,
        metadata: Optional[Dict] = None,
        x_is_one_hot_idx: bool = False,
    ):
        """
        Default constructor of AtomGraphDataset
        Args:
            dataset (Union[Dict[str, List], List]: dataset as dict or pure list
            metadata (Dict, Optional): metadata of data
            cutoff (float): cutoff radius of graphs inside the dataset
            x_is_one_hot_idx (bool): if True, x is one_hot_idx, else 'Z'

        'x' (node feature) of dataset can have 3 states, atomic_numbers,
        one_hot_idx, or one_hot_vector.

        atomic_numbers is general but cannot directly used for input
        one_hot_idx is can be input of the model but requires 'type_map'
        """
        self.cutoff = cutoff
        self.x_is_one_hot_idx = x_is_one_hot_idx
        if metadata is None:
            metadata = {KEY.CUTOFF: cutoff}
        self.meta = metadata
        if isinstance(dataset, list):
            self.dataset = {self.KEY_DEFAULT: dataset}
        else:
            self.dataset = dataset
        self.user_labels = list(self.dataset.keys())
        # group_by_key here? or not?
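
    # Usage sketch (data and labels are hypothetical; AtomGraphData entries
    # behave like dicts of tensors). A plain list is stored under KEY_DEFAULT:
    #   ds = AtomGraphDataset(graph_list, cutoff=5.0)
    #   ds_labeled = AtomGraphDataset(
    #       {'bulk': bulk_graphs, 'slab': slab_graphs}, cutoff=5.0
    #   )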

    def rewrite_labels_to_data(self):
        """
        Based on self.dataset dict's keys
        write data[KEY.USER_LABEL] to correspond to dict's keys
        Most of times, it is already correctly written
        But required to rewrite if someone rearrange dataset by their own way
        """
        for label, data_list in self.dataset.items():
            for data in data_list:
                data[KEY.USER_LABEL] = label

    def group_by_key(self, data_key: str = KEY.USER_LABEL):
        """
        group dataset list by given key and save it as dict
        and change in-place
        Args:
            data_key (str): data key to group by

        original use is USER_LABEL, but it can be used for other keys
        if someone established it from data[KEY.INFO]
        """
        data_list = self.to_list()
        self.dataset = {}
        for datum in data_list:
            key = datum[data_key]
            if key not in self.dataset:
                self.dataset[key] = []
            self.dataset[key].append(datum)
        self.user_labels = list(self.dataset.keys())
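
    # For example (labels hypothetical), group_by_key(KEY.USER_LABEL) turns
    # {KEY_DEFAULT: [d1, d2, d3]} into {'bulk': [d1, d3], 'slab': [d2]} when
    # d1 and d3 carry USER_LABEL 'bulk' and d2 carries 'slab'.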

    def separate_info(self, data_key: str = KEY.INFO):
        """
        Separate info from data and save it as list of dict
        to make it compatible with torch_geometric and later training
        """
        data_list = self.to_list()
        info_list = []
        for datum in data_list:
            if data_key not in datum:
                continue
            info_list.append(datum[data_key])
            del datum[data_key]  # It does change the self.dataset
            datum[data_key] = len(info_list) - 1
        self.info_list = info_list

        return (data_list, info_list)
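
    # After separate_info, each datum holds an integer index into
    # self.info_list in place of the original info dict, e.g.
    # datum[KEY.INFO] == 3 refers to self.info_list[3]; this keeps
    # non-tensor dicts out of torch_geometric batching.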

    def get_species(self):
        """
        You can also use get_natoms and extract keys from there instead of this
        (And it is more efficient)
        get chemical species of dataset
        return list of SORTED chemical species (as str)
        """
        if hasattr(self, 'type_map'):
            natoms = self.get_natoms(self.type_map)
        else:
            natoms = self.get_natoms()
        species = set()
        for natom_dct in natoms.values():
            species.update(natom_dct.keys())
        species = sorted(species)
        return species

    def get_modalities(self):
        modalities = set()
        for data_list in self.dataset.values():
            datum = data_list[0].to_dict()
            if KEY.DATA_MODALITY in datum:
                modalities.add(datum[KEY.DATA_MODALITY])
            else:
                return []
        return list(modalities)

    def write_modal_attr(
        self, modal_type_mapper: dict, write_modal_type: bool = False
    ):
        num_modalities = len(modal_type_mapper)
        for data_list in self.dataset.values():
            for data in data_list:
                tmp_tensor = torch.zeros(num_modalities)
                if data[KEY.DATA_MODALITY] != 'common':
                    modal_idx = modal_type_mapper[data[KEY.DATA_MODALITY]]
                    tmp_tensor[modal_idx] = 1.0
                    if write_modal_type:
                        data[KEY.MODAL_TYPE] = modal_idx
                data[KEY.MODAL_ATTR] = tmp_tensor
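
    # E.g. with modal_type_mapper == {'pbe': 0, 'scan': 1} (hypothetical
    # modal names), a 'scan' datum gets MODAL_ATTR == tensor([0., 1.]),
    # while 'common' data keep the all-zero vector.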

    def get_dict_sort_by_modality(self):
        dict_sort_by_modality = {}
        for data_list in self.dataset.values():
            try:
                modal_key = data_list[0].to_dict()[KEY.DATA_MODALITY]
            except KeyError:  # dataset is not modal
                raise ValueError('This dataset has no modality.')

            if modal_key not in dict_sort_by_modality:
                dict_sort_by_modality[modal_key] = []
            dict_sort_by_modality[modal_key].extend(data_list)

        return dict_sort_by_modality

    def len(self):
        if (
            len(self.dataset.keys()) == 1
            and list(self.dataset.keys())[0] == AtomGraphDataset.KEY_DEFAULT
        ):
            return len(self.dataset[AtomGraphDataset.KEY_DEFAULT])
        else:
            return {k: len(v) for k, v in self.dataset.items()}

    def get(self, idx: int, key: Optional[str] = None):
        if key is None:
            key = self.KEY_DEFAULT
        return self.dataset[key][idx]

    def items(self):
        return self.dataset.items()

    def to_dict(self):
        dct_dataset = {}
        for label, data_list in self.dataset.items():
            dct_dataset[label] = [datum.to_dict() for datum in data_list]
        self.dataset = dct_dataset
        return self

    def x_to_one_hot_idx(self, type_map: Dict[int, int]):
        """
        type_map is dict of {atomic_number: one_hot_idx}
        after this process, the dataset has dependency on type_map
        or chemical species user want to consider
        """
        assert self.x_is_one_hot_idx is False
        for data_list in self.dataset.values():
            for datum in data_list:
                datum[self.DATA_KEY_X] = torch.LongTensor(
                    [type_map[z.item()] for z in datum[self.DATA_KEY_X]]
                )
        self.type_map = type_map
        self.x_is_one_hot_idx = True
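
    # A minimal sketch of the expected type_map (species are hypothetical):
    # for an Si-O model, type_map == {14: 0, 8: 1} converts x == [14, 8, 8]
    # (atomic numbers) into [0, 1, 1] (one-hot indices).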

    def toggle_requires_grad_of_data(
        self, key: str, requires_grad_value: bool
    ):
        """
        set requires_grad of specific key of data(pos, edge_vec, ...)
        """
        for data_list in self.dataset.values():
            for datum in data_list:
                datum[key].requires_grad_(requires_grad_value)

    def divide_dataset(
        self,
        ratio: float,
        constant_ratio_btw_labels: bool = True,
        ignore_test: bool = True
    ):
        """
        divide dataset into 1-2*ratio : ratio : ratio
        return divided AtomGraphDataset
        returned value lost its dict key and became {KEY_DEFAULT: datalist}
        but KEY.USER_LABEL of each data is preserved
        """

        def divide(ratio: float, data_list: List, ignore_test=True):
            if ratio > 0.5:
                raise ValueError('Ratio must not exceed 0.5')
            data_len = len(data_list)
            random.shuffle(data_list)
            n_validation = int(data_len * ratio)
            if n_validation == 0:
                raise ValueError(
                    'Validation set size is 0; increase the dataset or ratio'
                )

            if ignore_test:
                test_list = []
                n_train = data_len - n_validation
                train_list = data_list[0:n_train]
                valid_list = data_list[n_train:]
            else:
                n_train = data_len - 2 * n_validation
                train_list = data_list[0:n_train]
                valid_list = data_list[n_train : n_train + n_validation]
                test_list = data_list[n_train + n_validation : data_len]
            return train_list, valid_list, test_list

        lists = ([], [], [])  # train, valid, test
        if constant_ratio_btw_labels:
            for data_list in self.dataset.values():
                for store, divided in zip(lists, divide(ratio, data_list)):
                    store.extend(divided)
        else:
            lists = divide(ratio, self.to_list())

        dbs = tuple(
            AtomGraphDataset(data, self.cutoff, self.meta) for data in lists
        )
        for db in dbs:
            db.group_by_key()
        return dbs
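
    # Usage sketch (ratio is hypothetical): with ratio=0.1 and the default
    # ignore_test=True, each label's list is shuffled and split 90/10, so
    #   train, valid, test = ds.divide_dataset(0.1)
    # returns an empty test set, while the default constant_ratio_btw_labels
    # keeps the per-label proportions in train and valid.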

    def to_list(self):
        return list(itertools.chain(*self.dataset.values()))

    def get_natoms(self, type_map: Optional[Dict[int, int]] = None):
        """
        if x_is_one_hot_idx, type_map is required
        type_map: Z->one_hot_index(node_feature)
        return Dict{label: {symbol, natom}]}
        """
        assert not (self.x_is_one_hot_idx is True and type_map is None)
        natoms = {}
        for label, data in self.dataset.items():
            natoms[label] = Counter()
            for datum in data:
                if self.x_is_one_hot_idx and type_map is not None:
                    Zs = util.onehot_to_chem(datum[self.DATA_KEY_X], type_map)
                else:
                    Zs = [
                        chemical_symbols[z]
                        for z in datum[self.DATA_KEY_X].tolist()
                    ]
                cnt = Counter(Zs)
                natoms[label] += cnt
            natoms[label] = dict(natoms[label])
        return natoms
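
    # Example of the returned structure (labels and counts hypothetical):
    #   {'bulk': {'O': 1024, 'Si': 512}, 'slab': {'O': 160, 'Si': 96}}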

    def get_per_atom_mean(self, key: str, key_num_atoms: str = KEY.NUM_ATOMS):
        """
        return per_atom mean of given data key
        """
        eng_list = torch.Tensor(
            [x[key] / x[key_num_atoms] for x in self.to_list()]
        )
        return float(torch.mean(eng_list))

    def get_per_atom_energy_mean(self):
        """
        alias for get_per_atom_mean(KEY.ENERGY)
        """
        return self.get_per_atom_mean(self.DATA_KEY_ENERGY)

    def get_species_ref_energy_by_linear_comb(self, num_chem_species: int):
        """
        Total energy as y, composition as c_i,
        solve linear regression of y = c_i*X
        sklearn LinearRegression as solver

        x should be one-hot-indexed
        give num_chem_species if possible
        """
        assert self.x_is_one_hot_idx is True
        data_list = self.to_list()

        c = torch.zeros((len(data_list), num_chem_species))
        for idx, datum in enumerate(data_list):
            c[idx] = torch.bincount(
                datum[self.DATA_KEY_X], minlength=num_chem_species
            )
        y = torch.Tensor([x[self.DATA_KEY_ENERGY] for x in data_list])
        c = c.numpy()
        y = y.numpy()

        # Tweak for fine-tuning from a many-element model to a few-element
        # dataset: fit only the species that actually appear; the rest stay 0.
        zero_indices = np.all(c == 0, axis=0)
        c_reduced = c[:, ~zero_indices]
        full_coeff = np.zeros(num_chem_species)
        coef_reduced = (
            Ridge(alpha=0.1, fit_intercept=False).fit(c_reduced, y).coef_
        )
        full_coeff[~zero_indices] = coef_reduced

        return full_coeff
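
    # A worked sketch of the regression above (numbers hypothetical):
    # with one-hot indices {0: Si, 1: O}, two structures give
    #   c = [[2, 4],   # Si2O4, y = -40.0
    #        [2, 2]]   # Si2O2, y = -28.0
    # and Ridge(fit_intercept=False) fits y ~= c @ coef, giving roughly
    # coef ~= [-8.0, -6.0] (up to the small alpha=0.1 penalty). Species
    # absent from every structure keep a coefficient of exactly 0.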

    def get_force_rms(self):
        force_list = []
        for x in self.to_list():
            force_list.extend(
                x[self.DATA_KEY_FORCE]
                .reshape(
                    -1,
                )
                .tolist()
            )
        force_list = torch.Tensor(force_list)
        return float(torch.sqrt(torch.mean(torch.pow(force_list, 2))))

    def get_species_wise_force_rms(self, num_chem_species: int):
        """
        Return force rms for each species
        Averaged by each components (x, y, z)
        """
        assert self.x_is_one_hot_idx is True
        data_list = self.to_list()

        atomx = torch.concat([d[self.DATA_KEY_X] for d in data_list])
        force = torch.concat([d[self.DATA_KEY_FORCE] for d in data_list])

        index = atomx.repeat_interleave(3, 0).reshape(force.shape)
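        # 'index' repeats each atom's species index across its 3 force
        # components, so the scatter_reduce_ below averages force**2 per
        # (species, component) pair before sqrt(mean over components).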
        rms = torch.zeros(
            (num_chem_species, 3),
            dtype=force.dtype,
            device=force.device
        )
        rms.scatter_reduce_(
            0, index, force.square(),
            reduce='mean', include_self=False
        )
        return torch.sqrt(rms.mean(dim=1))

    def get_avg_num_neigh(self):
        n_neigh = []
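        # KEY.EDGE_IDX[0] lists the source node of every directed edge, so
        # np.unique(..., return_counts=True) yields each connected atom's
        # neighbor count within the cutoff.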
        for _, data_list in self.dataset.items():
            for data in data_list:
                n_neigh.extend(
                    np.unique(data[KEY.EDGE_IDX][0], return_counts=True)[1]
                )

        avg_num_neigh = np.average(n_neigh)
        return avg_num_neigh

    def get_statistics(self, key: str):
        """
        return dict of statistics of given key (energy, force, stress)
        key of dict is its label and _total for total statistics
        value of dict is dict of statistics (mean, std, median, max, min)
        """

        def _get_statistic_dict(tensor_list):
            data_list = torch.cat(
                [
                    tensor.reshape(
                        -1,
                    )
                    for tensor in tensor_list
                ]
            )
            data_list = data_list[~torch.isnan(data_list)]
            return {
                'mean': float(torch.mean(data_list)),
                'std': float(torch.std(data_list)),
                'median': float(torch.median(data_list)),
                'max': (
                    torch.nan
                    if data_list.numel() == 0
                    else float(torch.max(data_list))
                ),
                'min': (
                    torch.nan
                    if data_list.numel() == 0
                    else float(torch.min(data_list))
                ),
            }

        res = {}
        for label, values in self.dataset.items():
            # flatten list of torch.Tensor (values)
            tensor_list = [x[key] for x in values]
            res[label] = _get_statistic_dict(tensor_list)
        tensor_list = [x[key] for x in self.to_list()]
        res['Total'] = _get_statistic_dict(tensor_list)
        return res

    def augment(self, dataset, validator: Optional[Callable] = None):
        """check meta compatibility here
        dataset(AtomGraphDataset): data to augment
        validator(Callable, Optional): function(self, dataset) -> bool

        if validator is None, by default it checks
        whether cutoff & chemical_species are same before augment

        check consistent data type, float, double, long integer etc
        """

        def default_validator(db1, db2):
            cut_consis = db1.cutoff == db2.cutoff
            # both datasets must still carry raw atomic numbers as x
            x_is_not_onehot = (not db1.x_is_one_hot_idx) and (
                not db2.x_is_one_hot_idx
            )
            return cut_consis and x_is_not_onehot

        if validator is None:
            validator = default_validator
        if not validator(self, dataset):
            raise ValueError('Datasets are not compatible; check cutoffs')
        for key, val in dataset.items():
            if key in self.dataset:
                self.dataset[key].extend(val)
            else:
                self.dataset.update({key: val})
        self.user_labels = list(self.dataset.keys())

    def unify_dtypes(
        self,
        float_dtype: torch.dtype = torch.float32,
        int_dtype: torch.dtype = torch.int64
    ):
        data_list = self.to_list()
        for datum in data_list:
            for k, v in list(datum.items()):
                datum[k] = util.dtype_correct(v, float_dtype, int_dtype)

    def delete_data_key(self, key: str):
        for data in self.to_list():
            del data[key]

    # TODO: this by_label is not straightforward
    def save(self, path: str, by_label: bool = False):
        if by_label:
            for label, data in self.dataset.items():
                torch.save(
                    AtomGraphDataset(
                        {label: data}, self.cutoff, metadata=self.meta
                    ),
                    f'{path}/{label}.sevenn_data',
                )
        else:
            if not path.endswith('.sevenn_data'):
                path += '.sevenn_data'
            torch.save(self, path)
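
# Round-trip sketch (path hypothetical): saved datasets are pickled Python
# objects, so they load back with
#   ds = torch.load('train.sevenn_data', map_location='cpu')
# (newer PyTorch versions may additionally require weights_only=False).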