import os

import torch
import torch.distributed as dist

import sevenn._const as CONST
import sevenn._keys as KEY
from sevenn.logger import Logger
from sevenn.train.dataload import file_to_dataset, match_reader
from sevenn.train.dataset import AtomGraphDataset
from sevenn.util import chemical_species_preprocess, onehot_to_chem


def dataset_load(file: str, config):
    """
    Wrapper around dataload.file_to_dataset that also supports
    graph-prebuilt .sevenn_data files.
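
    Example (hypothetical file name):
        dataset = dataset_load('my_structures.extxyz', config)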
    """
    log = Logger()
    log.write(f'Loading {file}\n')
    log.timer_start('loading dataset')

    if file.endswith('.sevenn_data'):
        dataset = torch.load(file, map_location='cpu', weights_only=False)
    else:
        reader, _ = match_reader(
            config[KEY.DATA_FORMAT], **config[KEY.DATA_FORMAT_ARGS]
        )
        dataset = file_to_dataset(
            file,
            config[KEY.CUTOFF],
            config[KEY.PREPROCESS_NUM_CORES],
            reader=reader,
            use_modality=config[KEY.USE_MODALITY],
            use_weight=config[KEY.USE_WEIGHT],
        )
    log.format_k_v('loaded dataset size is', dataset.len(), write=True)
    log.timer_end('loading dataset', 'data set loading time')
    return dataset


def calculate_shift_or_scale_from_key(
    train_set: AtomGraphDataset, key_given, n_chem
):
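    """
    Compute a shift or scale value from train_set for the given keyword.

    Returns (value, _expand, use_species_wise_shift_scale); the
    'elemwise_*' keywords yield per-species arrays, so they set _expand
    to False and use_species_wise_shift_scale to True.
    """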
    _expand = True
    use_species_wise_shift_scale = False
    if key_given == 'per_atom_energy_mean':
        shift_or_scale = train_set.get_per_atom_energy_mean()
    elif key_given == 'elemwise_reference_energies':
        shift_or_scale = train_set.get_species_ref_energy_by_linear_comb(n_chem)
        _expand = False
        use_species_wise_shift_scale = True

    elif key_given == 'force_rms':
        shift_or_scale = train_set.get_force_rms()
    elif key_given == 'per_atom_energy_std':
        shift_or_scale = train_set.get_statistics(KEY.PER_ATOM_ENERGY)['Total'][
            'std'
        ]
    elif key_given == 'elemwise_force_rms':
        shift_or_scale = train_set.get_species_wise_force_rms(n_chem)
        _expand = False
        use_species_wise_shift_scale = True
    else:
        raise ValueError(f'Unknown shift or scale keyword: {key_given}')

    return shift_or_scale, _expand, use_species_wise_shift_scale


def handle_shift_scale(config, train_set: AtomGraphDataset, checkpoint_given):
    """
    Resolve shift, scale, and conv_denominator for the model.

    Priority (highest first; each source overwrites the ones below it):
        1. Float value(s) given in yaml
        2. Statistics stored in the checkpoint
           (when use_statistic_values_of_checkpoint is True)
        3. Keyword options given as strings, computed from the dataset
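
    Example (hypothetical yaml): `shift: elemwise_reference_energies`
    computes per-species shifts from the training set, while a plain
    float like `shift: -3.5` would override both the computed value
    and any checkpoint statistics.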
    """
    log = Logger()
    shift, scale, conv_denominator = None, None, None
    type_map = config[KEY.TYPE_MAP]
    n_chem = len(type_map)
    chem_strs = onehot_to_chem(list(range(n_chem)), type_map)

    log.writeline('\nCalculating statistic values from dataset')

    shift_given = config[KEY.SHIFT]
    scale_given = config[KEY.SCALE]
    _expand_shift = True
    _expand_scale = True
    use_species_wise_shift = False
    use_species_wise_scale = False

    use_modal_wise_shift = config[KEY.USE_MODAL_WISE_SHIFT]
    use_modal_wise_scale = config[KEY.USE_MODAL_WISE_SCALE]

    if shift_given in CONST.IMPLEMENTED_SHIFT:
        shift, _expand_shift, use_species_wise_shift = (
            calculate_shift_or_scale_from_key(train_set, shift_given, n_chem)
        )

    if scale_given in CONST.IMPLEMENTED_SCALE:
        scale, _expand_scale, use_species_wise_scale = (
            calculate_shift_or_scale_from_key(train_set, scale_given, n_chem)
        )

    if use_modal_wise_shift or use_modal_wise_scale:
        atomdata_dict_sort_by_modal = train_set.get_dict_sort_by_modality()
        modal_map = config[KEY.MODAL_MAP]
        n_modal = len(modal_map)
        cutoff = config[KEY.CUTOFF]

        if use_modal_wise_shift:
            shift = torch.zeros((n_modal, n_chem))

        if use_modal_wise_scale:
            scale = torch.zeros((n_modal, n_chem))

        for modal_key, data_list in atomdata_dict_sort_by_modal.items():
            modal_set = AtomGraphDataset(data_list, cutoff, x_is_one_hot_idx=True)

            if use_modal_wise_shift:
                if shift_given == 'elemwise_reference_energies':
                    modal_shift, _expand_shift, use_species_wise_shift = (
                        calculate_shift_or_scale_from_key(
                            modal_set, shift_given, n_chem
                        )
                    )
                    shift[modal_map[modal_key]] = torch.tensor(
                        modal_shift
                    )  # modal_shift is an np.ndarray
                elif shift_given in CONST.IMPLEMENTED_SHIFT:
                    raise NotImplementedError(
                        'Currently, modal-wise shift is implemented for '
                        'the species-dependent case only.'
                    )

            if use_modal_wise_scale:
                if scale_given == 'elemwise_force_rms':
                    modal_scale, _expand_scale, use_species_wise_scale = (
                        calculate_shift_or_scale_from_key(
                            modal_set, scale_given, n_chem
                        )
                    )
                    scale[modal_map[modal_key]] = modal_scale
                elif scale_given in CONST.IMPLEMENTED_SCALE:
                    raise NotImplementedError(
                        'Currently, modal-wise scale is implemented for '
                        'the species-dependent case only.'
                    )

    avg_num_neigh = train_set.get_avg_num_neigh()
    log.format_k_v('Average # of neighbors', f'{avg_num_neigh:.6f}', write=True)

    if config[KEY.CONV_DENOMINATOR] == 'avg_num_neigh':
        conv_denominator = avg_num_neigh
    elif config[KEY.CONV_DENOMINATOR] == 'sqrt_avg_num_neigh':
        conv_denominator = avg_num_neigh ** (0.5)
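    # The chosen denominator later normalizes message aggregation in the
    # convolution layers; it is expanded to one value per layer below.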

    if (
        checkpoint_given
        and config[KEY.CONTINUE][KEY.USE_STATISTIC_VALUES_OF_CHECKPOINT]
    ):
        log.writeline(
            'Overwrite shift, scale, conv_denominator from model checkpoint'
        )
        # TODO: This needs refactoring
        conv_denominator = config[KEY.CONV_DENOMINATOR + '_cp']
        if not (use_modal_wise_shift or use_modal_wise_scale):
            # Values extracted from checkpoint in processing_continue.py
            if len(list(shift)) > 1:
                use_species_wise_shift = True
                use_species_wise_scale = True
                _expand_shift = _expand_scale = False
            else:
                shift = shift.item()
                scale = scale.item()
        else:
            # Case of modal wise shift scale
            shift_cp = config[KEY.SHIFT + '_cp']
            scale_cp = config[KEY.SCALE + '_cp']
            if not use_modal_wise_shift:
                shift = shift_cp
            if not use_modal_wise_scale:
                scale = scale_cp
            modal_map = config[KEY.MODAL_MAP]
            modal_map_cp = config[KEY.MODAL_MAP + '_cp']

            # Extracting shift, scale for modal in checkpoint model.
            if config[KEY.USE_MODALITY + '_cp']:  # cp model is multimodal
                for modal_key_cp, modal_idx_cp in modal_map_cp.items():
                    modal_idx = modal_map[modal_key_cp]
                    if use_modal_wise_shift:
                        shift[modal_idx] = torch.tensor(shift_cp[modal_idx_cp])
                    if use_modal_wise_scale:
                        scale[modal_idx] = torch.tensor(scale_cp[modal_idx_cp])

            else:  # cp model is single modal
                try:
                    modal_idx = modal_map[config[KEY.DEFAULT_MODAL]]
                except KeyError:
                    raise KeyError(
                        f'{config[KEY.DEFAULT_MODAL]} should be one of'
                        f' {modal_map.keys()}'
                    )
                if use_modal_wise_shift:
                    shift[modal_idx] = torch.tensor(shift_cp)
                if use_modal_wise_scale:
                    scale[modal_idx] = torch.tensor(scale_cp)

            if not config[KEY.CONTINUE][KEY.USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY]:
                # Also initialize values of new modals with a reference value.
                # For a multimodal checkpoint, KEY.DEFAULT_MODAL picks the
                # reference modal.
                shift_ref = shift_cp
                scale_ref = scale_cp
                if config[KEY.USE_MODALITY + '_cp']:
                    try:
                        modal_idx_cp = modal_map_cp[config[KEY.DEFAULT_MODAL]]
                    except KeyError:
                        raise KeyError(
                            f'{config[KEY.DEFAULT_MODAL]} should be one of'
                            f' {modal_map_cp.keys()}'
                        )
                    shift_ref = shift_cp[modal_idx_cp]
                    scale_ref = scale_cp[modal_idx_cp]

                for modal_key, modal_idx in modal_map.items():
                    if modal_key not in modal_map_cp.keys():
                        if use_modal_wise_shift:
                            shift[modal_idx] = shift_ref
                        if use_modal_wise_scale:
                            scale[modal_idx] = scale_ref

    # Shift and scale given explicitly in yaml always take precedence.
    if isinstance(shift_given, (list, float)):
        log.writeline('Overwrite shift to value(s) given in yaml')
        _expand_shift = isinstance(shift_given, float)
        shift = shift_given
    if isinstance(scale_given, (list, float)):
        log.writeline('Overwrite scale to value(s) given in yaml')
        _expand_scale = isinstance(scale_given, float)
        scale = scale_given

    if isinstance(config[KEY.CONV_DENOMINATOR], float):
        log.writeline('Overwrite conv_denominator to value given in yaml')
        conv_denominator = config[KEY.CONV_DENOMINATOR]

    if isinstance(conv_denominator, float):
        conv_denominator = [conv_denominator] * config[KEY.NUM_CONVOLUTION]

    use_species_wise_shift_scale = use_species_wise_shift or use_species_wise_scale
    if use_species_wise_shift_scale:
        if _expand_shift:
            if use_modal_wise_shift:
                shift = torch.full((n_modal, n_chem), shift)
            else:
                shift = [shift] * n_chem
        if _expand_scale:
            if use_modal_wise_scale:
                scale = torch.full((n_modal, n_chem), scale)
            else:
                scale = [scale] * n_chem

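        # At this point shift and scale are per-species: length-n_chem
        # sequences, or (n_modal, n_chem) tensors in the modal-wise case.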
        log.write('Use element-wise shift, scale\n')
        if use_modal_wise_shift or use_modal_wise_scale:
            for modal_key, modal_idx in modal_map.items():
                log.writeline(f'For modal = {modal_key}')
                print_shift = shift[modal_idx] if use_modal_wise_shift else shift
                print_scale = scale[modal_idx] if use_modal_wise_scale else scale
                for cstr, sh, sc in zip(chem_strs, print_shift, print_scale):
                    log.format_k_v(f'{cstr}', f'{sh:.6f}, {sc:.6f}', write=True)
        else:
            for cstr, sh, sc in zip(chem_strs, shift, scale):
                log.format_k_v(f'{cstr}', f'{sh:.6f}, {sc:.6f}', write=True)
    else:
        log.write('Use global shift, scale\n')
        log.format_k_v('shift, scale', f'{shift:.6f}, {scale:.6f}', write=True)

    assert isinstance(conv_denominator, list) and all(
        isinstance(deno, float) for deno in conv_denominator
    )
    log.format_k_v(
        '(1st) conv_denominator is', f'{conv_denominator[0]:.6f}', write=True
    )

    config[KEY.USE_SPECIES_WISE_SHIFT_SCALE] = use_species_wise_shift_scale
    return shift, scale, conv_denominator


# TODO: This is too long
def processing_dataset(config, working_dir):
    log = Logger()
    prefix = f'{os.path.abspath(working_dir)}/'
    is_stress = config[KEY.IS_TRAIN_STRESS]
    checkpoint_given = config[KEY.CONTINUE][KEY.CHECKPOINT] is not False
    cutoff = config[KEY.CUTOFF]

    log.write('\nInitializing dataset...\n')

    dataset = AtomGraphDataset({}, cutoff)
    load_dataset = config[KEY.LOAD_DATASET]
    if isinstance(load_dataset, str):
        load_dataset = [load_dataset]
    for file in load_dataset:
        dataset.augment(dataset_load(file, config))

    dataset.group_by_key()  # apply labels inside original datapoint
    dataset.unify_dtypes()  # unify dtypes of all data points

    # TODO: I think manual chemical species input is redundant
    chem_in_db = dataset.get_species()
    if config[KEY.CHEMICAL_SPECIES] == 'auto' and not checkpoint_given:
        log.writeline('Auto detect chemical species from dataset')
        config.update(chemical_species_preprocess(chem_in_db))
    elif config[KEY.CHEMICAL_SPECIES] == 'auto' and checkpoint_given:
        pass  # copied from checkpoint in processing_continue.py
    elif config[KEY.CHEMICAL_SPECIES] != 'auto' and not checkpoint_given:
        pass  # processed in parse_input.py
    else:  # config[KEY.CHEMICAL_SPECIES] != "auto" and checkpoint_given
        log.writeline('Ignore chemical species in yaml, use checkpoint')
        # already processed in processing_continue.py

    # basic dataset compatibility check with previous model
    if checkpoint_given:
        chem_from_cp = config[KEY.CHEMICAL_SPECIES]
        if not all(chem in chem_from_cp for chem in chem_in_db):
            raise ValueError(
                'Dataset contains chemical species not present in the checkpoint'
            )

    # check what modalities are used in dataset
    if config[KEY.USE_MODALITY]:
        modalities = dataset.get_modalities()
        num_modalities = len(modalities)
        if num_modalities < 2:
            log.writeline('Only one modality is given, ignoring modality')
            config.update({KEY.USE_MODALITY: False})

        else:
            modal_map_cp = config[KEY.MODAL_MAP + '_cp'] if checkpoint_given else {}
            modal_map = modal_map_cp.copy()
            current_idx = len(modal_map_cp)
            for modal_key in modalities:
                if modal_key not in modal_map.keys():
                    modal_map[modal_key] = current_idx
                    current_idx += 1

            if config[KEY.IS_DDP]:
                # Synchronize modal_map
                torch.cuda.set_device(config[KEY.LOCAL_RANK])
                modal_map_bcast = [modal_map]
                dist.broadcast_object_list(modal_map_bcast, src=0)
                modal_map = modal_map_bcast[0]
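                # Every rank now adopts rank 0's modal_map, keeping modal
                # indices consistent across DDP processes.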

            config.update(
                {
                    KEY.NUM_MODALITIES: len(modal_map),
                    KEY.MODAL_MAP: modal_map,
                    KEY.MODAL_LIST: list(modal_map.keys()),
                }
            )

            dataset.write_modal_attr(
                modal_map,
                config[KEY.USE_MODAL_WISE_SHIFT] or config[KEY.USE_MODAL_WISE_SCALE],
            )

    # --------------- save dataset regardless of train/valid--------------#
    save_dataset = config[KEY.SAVE_DATASET]
    save_by_label = config[KEY.SAVE_BY_LABEL]
    if save_dataset:
        if not save_dataset.endswith('.sevenn_data'):
            save_dataset += '.sevenn_data'
        if not save_dataset.startswith(('.', '/')):
            save_dataset = prefix + save_dataset  # save_dataset is a plain file name
        dataset.save(save_dataset)
        log.format_k_v('Dataset saved to', save_dataset, write=True)
        # log.write(f"Loaded full dataset saved to : {save_dataset}\n")
    if save_by_label:
        dataset.save(prefix, by_label=True)
        log.format_k_v('Dataset saved by label', prefix, write=True)
    # --------------------------------------------------------------------#

    # TODO: testset is not used
    ignore_test = not config.get(KEY.USE_TESTSET, False)
    if KEY.LOAD_VALIDSET in config and config[KEY.LOAD_VALIDSET]:
        train_set = dataset
        test_set = AtomGraphDataset([], config[KEY.CUTOFF])

        log.write('Loading validset from load_validset\n')
        valid_set = AtomGraphDataset({}, cutoff)
        for file in config[KEY.LOAD_VALIDSET]:
            valid_set.augment(dataset_load(file, config))
        valid_set.group_by_key()
        valid_set.unify_dtypes()

        # condition: validset labels should be subset of trainset labels
        valid_labels = valid_set.user_labels
        train_labels = train_set.user_labels
        if not set(valid_labels).issubset(train_labels):
            valid_set = AtomGraphDataset(valid_set.to_list(), cutoff)
            valid_set.rewrite_labels_to_data()
            train_set = AtomGraphDataset(train_set.to_list(), cutoff)
            train_set.rewrite_labels_to_data()
            log.write('WARNING! validset labels are not a subset of trainset\n')
            log.write('Overwriting all train and valid labels to default.\n')
            log.write('Please create the validset via sevenn_graph_build with -l\n')

        log.write('The validset is loaded; load_dataset is used as the train_set\n')
        log.write('The train/valid ratio will be ignored\n')

        # condition: validset modalities should be subset of trainset modalities
        if config[KEY.USE_MODALITY]:
            config_modality = config[KEY.MODAL_LIST]
            valid_modality = valid_set.get_modalities()

            if not set(valid_modality).issubset(config_modality):
                raise ValueError(
                    'validset modalities are not a subset of trainset modalities'
                )

            valid_set.write_modal_attr(
                config[KEY.MODAL_MAP],
                config[KEY.USE_MODAL_WISE_SHIFT] or config[KEY.USE_MODAL_WISE_SCALE],
            )
    else:
        train_set, valid_set, test_set = dataset.divide_dataset(
            config[KEY.RATIO], ignore_test=ignore_test
        )
        log.write(
            f'The dataset is divided into train and valid sets by {KEY.RATIO}\n'
        )

    log.format_k_v('\nloaded trainset size is', train_set.len(), write=True)
    log.format_k_v('loaded validset size is', valid_set.len(), write=True)

    log.write('Dataset initialization was successful\n')

    log.write('\nNumber of atoms in the train_set:\n')
    log.natoms_write(train_set.get_natoms(config[KEY.TYPE_MAP]))

    log.bar()
    log.write('Per-atom energy (eV/atom) distribution:\n')
    log.statistic_write(train_set.get_statistics(KEY.PER_ATOM_ENERGY))
    log.bar()
    log.write('Force (eV/Angstrom) distribution:\n')
    log.statistic_write(train_set.get_statistics(KEY.FORCE))
    log.bar()
    log.write('Stress (eV/Angstrom^3) distribution:\n')
    try:
        log.statistic_write(train_set.get_statistics(KEY.STRESS))
    except KeyError:
        log.write('\nStress is not included in the train_set\n')
        if is_stress:
            is_stress = False
            log.write('Turn off stress training\n')
    log.bar()

    # Saved data must store atomic numbers in x, not one-hot indices
    if config[KEY.SAVE_BY_TRAIN_VALID]:
        train_set.save(prefix + 'train')
        valid_set.save(prefix + 'valid')
        log.format_k_v('Dataset saved by train, valid', prefix, write=True)

    # Inconsistent .info dicts cause errors when collating
    _, _ = train_set.separate_info()
    _, _ = valid_set.separate_info()

    if not train_set.x_is_one_hot_idx:
        train_set.x_to_one_hot_idx(config[KEY.TYPE_MAP])
    if not valid_set.x_is_one_hot_idx:
        valid_set.x_to_one_hot_idx(config[KEY.TYPE_MAP])

    log.format_k_v('training_set size', train_set.len(), write=True)
    log.format_k_v('validation_set size', valid_set.len(), write=True)

    shift, scale, conv_denominator = handle_shift_scale(
        config, train_set, checkpoint_given
    )
    config.update(
        {
            KEY.SHIFT: shift,
            KEY.SCALE: scale,
            KEY.CONV_DENOMINATOR: conv_denominator,
        }
    )

    data_lists = (train_set.to_list(), valid_set.to_list(), test_set.to_list())

    return data_lists
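

# A minimal sketch of the expected call site (hypothetical wiring; the actual
# entry point builds `config` via parse_input / processing_continue):
#     train_list, valid_list, test_list = processing_dataset(config, os.getcwd())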