self_connection.py 3.95 KB
Newer Older
zcxzcx1's avatar
zcxzcx1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import torch.nn as nn
from e3nn.o3 import FullyConnectedTensorProduct, Irreps, Linear
from e3nn.util.jit import compile_mode

import sevenn._keys as KEY
from sevenn._const import AtomGraphDataType


@compile_mode('script')
class SelfConnectionIntro(nn.Module):
    """
    Tensor-product style self connection (intro half).

    Computes FullyConnectedTensorProduct(x, operand) — by default the node
    feature with the node attribute — and stores the result under
    KEY.SELF_CONNECTION_TEMP so that SelfConnectionOutro can later add it
    to the updated node features.
    """

    def __init__(
        self,
        irreps_in: Irreps,
        irreps_operand: Irreps,
        irreps_out: Irreps,
        data_key_x: str = KEY.NODE_FEATURE,
        data_key_operand: str = KEY.NODE_ATTR,
        lazy_layer_instantiate: bool = True,
        **kwargs,  # for compatibility
    ):
        """
        Args:
            irreps_in: irreps of the first tensor-product input (x)
            irreps_operand: irreps of the second tensor-product input
            irreps_out: irreps of the tensor-product output
            data_key_x: data dict key for x (default: node feature)
            data_key_operand: data dict key for the operand (default: node attr)
            lazy_layer_instantiate: if True, defer layer creation to
                instantiate(); if False, create the layer immediately
            **kwargs: forwarded to FullyConnectedTensorProduct
        """
        super().__init__()

        self.irreps_in1 = irreps_in
        self.irreps_in2 = irreps_operand
        self.irreps_out = irreps_out

        self.key_x = data_key_x
        self.key_operand = data_key_operand

        # BUG FIX: the original eagerly built a FullyConnectedTensorProduct
        # here and then immediately overwrote the attribute with None,
        # discarding the freshly constructed layer (and its parameter
        # initialization). The layer is created only in instantiate().
        self.fc_tensor_product = None
        self.layer_instantiated = False
        self.fc_tensor_product_cls = FullyConnectedTensorProduct
        self.fc_tensor_product_kwargs = kwargs

        if not lazy_layer_instantiate:
            self.instantiate()

    def instantiate(self):
        """Create the tensor-product layer.

        Raises:
            ValueError: if the layer has already been instantiated.
        """
        if self.fc_tensor_product is not None:
            raise ValueError('fc_tensor_product layer already exists')
        self.fc_tensor_product = self.fc_tensor_product_cls(
            self.irreps_in1,
            self.irreps_in2,
            self.irreps_out,
            shared_weights=True,
            internal_weights=None,  # same as True
            **self.fc_tensor_product_kwargs,
        )
        self.layer_instantiated = True

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        """Store TP(x, operand) under KEY.SELF_CONNECTION_TEMP."""
        assert self.fc_tensor_product is not None, 'Layer is not instantiated'
        data[KEY.SELF_CONNECTION_TEMP] = self.fc_tensor_product(
            data[self.key_x], data[self.key_operand]
        )
        return data


@compile_mode('script')
class SelfConnectionLinearIntro(nn.Module):
    """
    Linear style self connection (intro half).

    Applies an equivariant Linear layer to x and stores the result under
    KEY.SELF_CONNECTION_TEMP so that SelfConnectionOutro can later add it
    to the updated node features.
    """

    def __init__(
        self,
        irreps_in: Irreps,
        irreps_out: Irreps,
        data_key_x: str = KEY.NODE_FEATURE,
        lazy_layer_instantiate: bool = True,
        **kwargs,
    ):
        """
        Args:
            irreps_in: irreps of the Linear input (x)
            irreps_out: irreps of the Linear output
            data_key_x: data dict key for x (default: node feature)
            lazy_layer_instantiate: if True, defer layer creation to
                instantiate(); if False, create the layer immediately
            **kwargs: forwarded to Linear ('irreps_operand' is dropped)
        """
        super().__init__()
        self.irreps_in = irreps_in
        self.irreps_out = irreps_out
        self.key_x = data_key_x

        self.linear = None
        self.layer_instantiated = False
        self.linear_cls = Linear

        # TODO: better to have SelfConnectionIntro super class
        # BUG FIX: pop with a default — the original pop('irreps_operand')
        # raised KeyError when the caller did not pass that key (it is only
        # meaningful for the tensor-product variant).
        kwargs.pop('irreps_operand', None)
        self.linear_kwargs = kwargs

        if not lazy_layer_instantiate:
            self.instantiate()

    def instantiate(self):
        """Create the Linear layer.

        Raises:
            ValueError: if the layer has already been instantiated.
        """
        if self.linear is not None:
            raise ValueError('Linear layer already exists')
        self.linear = self.linear_cls(
            self.irreps_in, self.irreps_out, **self.linear_kwargs
        )
        self.layer_instantiated = True

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        """Store Linear(x) under KEY.SELF_CONNECTION_TEMP."""
        assert self.linear is not None, 'Layer is not instantiated'
        data[KEY.SELF_CONNECTION_TEMP] = self.linear(data[self.key_x])
        return data


@compile_mode('script')
class SelfConnectionOutro(nn.Module):
    """
    Outro half of the self connection.

    Adds the term stored under KEY.SELF_CONNECTION_TEMP (written by one of
    the *Intro modules) to the node features, then removes the temporary
    entry from the data dict.
    """

    def __init__(
        self,
        data_key_x: str = KEY.NODE_FEATURE,
    ):
        """
        Args:
            data_key_x: data dict key for x (default: node feature)
        """
        super().__init__()
        self.key_x = data_key_x

    def forward(self, data: AtomGraphDataType) -> AtomGraphDataType:
        """Apply the residual-style self-connection update to x."""
        self_connection = data[KEY.SELF_CONNECTION_TEMP]
        data[self.key_x] = data[self.key_x] + self_connection
        del data[KEY.SELF_CONNECTION_TEMP]
        return data