node_embedding.py 2.88 KB
Newer Older
zcxzcx1's avatar
zcxzcx1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from typing import Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional
from ase.symbols import symbols2numbers
from e3nn.util.jit import compile_mode

import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType


# TODO: move this into model_build and stop one-hot preprocessing the data
@compile_mode('script')
class OnehotEmbedding(nn.Module):
    """
    One-hot encode integer type indices stored in ``data``.

    Reads ``data[data_key_x]`` (integer type indices), converts it with
    ``torch.nn.functional.one_hot(x, num_classes)`` and casts the result
    to float32. ``one_hot`` appends a trailing dimension of size
    ``num_classes``, so an input of shape (N,) becomes (N, num_classes).
    NOTE(review): an older docstring described the input as (N, 1), which
    would yield (N, 1, num_classes); the example below uses (N,) --
    confirm against callers.

    The one-hot tensor is written to ``data_key_out``. If that is None,
    it defaults to ``data_key_x``, so the input entry is overwritten. When
    ``data_key_additional`` is set, the same one-hot tensor is also stored
    there (for the self-connection). When ``data_key_save`` is set, the
    original index tensor is kept there (for element-wise shift/scale).
    This layout is unusual but is kept for compatibility with earlier
    versions.

    ex) [0 1 1 0] -> [[1, 0] [0, 1] [0, 1] [1, 0]] (num_classes = 2)
    """

    def __init__(
        self,
        num_classes: int,
        data_key_x: str = KEY.NODE_FEATURE,
        data_key_out: Optional[str] = None,
        data_key_save: Optional[str] = None,
        data_key_additional: Optional[str] = None,  # additional output
    ):
        super().__init__()
        self.num_classes = num_classes
        self.key_x = data_key_x
        # Without an explicit output key the input entry is overwritten.
        self.key_output = (
            data_key_x if data_key_out is None else data_key_out
        )
        self.key_save = data_key_save
        self.key_additional_output = data_key_additional

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        type_index = data[self.key_x]
        onehot = torch.nn.functional.one_hot(
            type_index, self.num_classes
        ).float()

        # Write order is preserved on purpose: if keys coincide, the
        # later assignment wins.
        data[self.key_output] = onehot
        if self.key_additional_output is not None:
            # Same tensor reused for the self-connection input.
            data[self.key_additional_output] = onehot
        if self.key_save is not None:
            # Keep the raw indices for element-wise shift/scale.
            data[self.key_save] = type_index
        return data


def get_type_mapper_from_specie(specie_list: List[str]):
    """
    Map atomic numbers to contiguous type indices starting at 0.

    Species symbols are visited in sorted (string) order. Each atomic
    number seen for the first time is assigned the next free index, and
    repeated atomic numbers are skipped. Only the first atomic number
    parsed from each symbol string is used.

    ex) ['Hf', 'O'] -> {72: 0, 8: 1}
    """
    type_map = {}
    for symbol in sorted(specie_list):
        z = symbols2numbers(symbol)[0]
        if z not in type_map:
            # The map's current size is exactly the next unused index.
            type_map[z] = len(type_map)
    return type_map


# deprecated
def one_hot_atom_embedding(
    atomic_numbers: List[int], type_map: Dict[int, int]
):
    """
    Deprecated: one-hot encode atomic numbers with a precomputed type map.

    Args:
        atomic_numbers: atomic numbers, e.g. from ``ase`` ``get_atomic_numbers``.
        type_map: atomic number -> type index, e.g. from
            ``get_type_mapper_from_specie()``.

    Returns:
        Tensor of shape (len(atomic_numbers), len(type_map)) in the current
        torch default floating dtype.

    Raises:
        ValueError: if an atomic number is not a key of ``type_map``
            (the original ``KeyError`` is chained as ``__cause__``).
    """
    num_classes = len(type_map)
    # Only the dict lookup can raise KeyError, so keep the try body minimal.
    try:
        type_indices = [type_map[num] for num in atomic_numbers]
    except KeyError as err:
        raise ValueError(
            f'Atomic number {err.args[0]} is not expected'
        ) from err
    # torch.tensor with an explicit dtype replaces the legacy
    # torch.LongTensor constructor; an empty list still gives shape (0,).
    type_numbers = torch.tensor(type_indices, dtype=torch.long)
    embd = torch.nn.functional.one_hot(type_numbers, num_classes)
    return embd.to(torch.get_default_dtype())