import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
import dgl.function as fn
import dgl.nn.pytorch as dglnn
def disable_grad(module):
    """Freeze *module*: turn off gradient tracking on all of its parameters."""
    for parameter in module.parameters():
        parameter.requires_grad = False

def _init_input_modules(g, ntype, textset, hidden_dims):
    """Build a ModuleDict mapping each input feature column of ``ntype`` to a
    module that projects it into a ``hidden_dims``-sized vector.

    Per-column heuristics:
    * float32 2-D tensor -> treated as a numeric vector; linear projection.
    * int64 1-D tensor   -> treated as categorical; embedding table covering
      0..max(x) plus a padding row addressed as index -1.
    * textset field      -> bag-of-words encoder over the field's vocabulary.
    """
    modules = nn.ModuleDict()

    for name, feat in g.nodes[ntype].data.items():
        # The node-ID column is bookkeeping, not an input feature.
        if name == dgl.NID:
            continue
        if feat.dtype == torch.float32:
            assert feat.ndim == 2
            proj = nn.Linear(feat.shape[1], hidden_dims)
            nn.init.xavier_uniform_(proj.weight)
            nn.init.constant_(proj.bias, 0)
            modules[name] = proj
        elif feat.dtype == torch.int64:
            assert feat.ndim == 1
            # max(x) + 2 rows: indices 0..max(x) plus one extra row that
            # serves as the padding entry (padding_idx=-1 -> last row).
            emb = nn.Embedding(feat.max() + 2, hidden_dims, padding_idx=-1)
            nn.init.xavier_uniform_(emb.weight)
            modules[name] = emb

    if textset is not None:
        for name, field in textset.items():
            # Field layout assumed: (textlist, vocab, pad_var, batch_first)
            # — only the vocabulary is needed here.
            _textlist, vocab, _pad_var, _batch_first = field
            modules[name] = BagOfWords(vocab, hidden_dims)

    return modules

class BagOfWords(nn.Module):
    """Encode a padded matrix of token IDs as the mean of its token embeddings.

    ``forward`` sums the embeddings along the sequence dimension and divides
    by the *true* (unpadded) length, so the embedding row for the ``<pad>``
    token must be all zeros for padded positions to contribute nothing.
    """

    def __init__(self, vocab, hidden_dims):
        """
        vocab : vocabulary object exposing ``get_itos()`` / ``get_stoi()``
            (torchtext-style); must contain a ``"<pad>"`` token.
        hidden_dims : int
            Embedding dimensionality.
        """
        super().__init__()
        pad_index = vocab.get_stoi()["<pad>"]
        self.emb = nn.Embedding(
            len(vocab.get_itos()),
            hidden_dims,
            padding_idx=pad_index,
        )
        nn.init.xavier_uniform_(self.emb.weight)
        # BUG FIX: xavier_uniform_ overwrites the entire table, including the
        # padding row that nn.Embedding had zeroed. Re-zero it so padded
        # positions do not pollute the bag-of-words sum below.
        with torch.no_grad():
            self.emb.weight[pad_index].fill_(0)

    def forward(self, x, length):
        """
        x : (batch, max_len) int64 token IDs, padded with the ``<pad>`` index.
        length : (batch,) true sequence lengths.
        """
        return self.emb(x).sum(1) / length.unsqueeze(1).float()

class LinearProjector(nn.Module):
    """Project every input feature of one node type to ``hidden_dims`` and
    sum the per-feature projections into a single vector per node."""

    def __init__(self, full_graph, ntype, textset, hidden_dims):
        super().__init__()

        self.ntype = ntype
        self.inputs = _init_input_modules(
            full_graph, ntype, textset, hidden_dims
        )

    def forward(self, ndata):
        outputs = []
        for name, value in ndata.items():
            # Skip bookkeeping columns: the node-ID column and the
            # "<feature>__len" companions that only record text lengths.
            if name == dgl.NID or name.endswith("__len"):
                continue

            proj = self.inputs[name]
            if isinstance(proj, BagOfWords):
                # Textual feature: the encoder also needs the true lengths.
                outputs.append(proj(value, ndata[name + "__len"]))
            else:
                outputs.append(proj(value))

        return torch.stack(outputs, 1).sum(1)

class WeightedSAGEConv(nn.Module):
    """GraphSAGE-style convolution with scalar edge weights.

    Neighbor features are transformed by ``Q``, averaged with the edge
    weights, concatenated with the destination features, transformed by
    ``W``, and finally L2-normalized.
    """

    def __init__(self, input_dims, hidden_dims, output_dims, act=F.relu):
        super().__init__()

        self.act = act
        self.Q = nn.Linear(input_dims, hidden_dims)
        self.W = nn.Linear(input_dims + hidden_dims, output_dims)
        self.reset_parameters()
        self.dropout = nn.Dropout(0.5)

    def reset_parameters(self):
        """(Re)initialize both linear layers: Xavier weights, zero biases."""
        gain = nn.init.calculate_gain("relu")
        for layer in (self.Q, self.W):
            nn.init.xavier_uniform_(layer.weight, gain=gain)
            nn.init.constant_(layer.bias, 0)

    def forward(self, g, h, weights):
        """
        g : graph
        h : (source features, destination features) pair
        weights : scalar edge weights
        """
        h_src, h_dst = h
        with g.local_scope():
            # Transformed neighbor messages.
            g.srcdata["n"] = self.act(self.Q(self.dropout(h_src)))
            g.edata["w"] = weights.float()
            # Weighted sum of messages, and total incoming weight per node.
            g.update_all(fn.u_mul_e("n", "w", "m"), fn.sum("m", "n"))
            g.update_all(fn.copy_e("w", "m"), fn.sum("m", "ws"))
            n = g.dstdata["n"]
            # clamp avoids dividing by zero for isolated destination nodes.
            ws = g.dstdata["ws"].unsqueeze(1).clamp(min=1)
            z = self.act(self.W(self.dropout(torch.cat([n / ws, h_dst], 1))))
            # L2-normalize, guarding against all-zero rows.
            z_norm = z.norm(2, 1, keepdim=True)
            z_norm = torch.where(
                z_norm == 0, torch.tensor(1.0).to(z_norm), z_norm
            )
            return z / z_norm

class SAGENet(nn.Module):
    """Stack of ``n_layers`` WeightedSAGEConv layers, applied over a list of
    DGL blocks (one block per layer).

    Parameters
    ----------
    hidden_dims : int
        Width of every layer's input, hidden, and output representations.
    n_layers : int
        Number of stacked convolution layers.
    """

    def __init__(self, hidden_dims, n_layers):
        super().__init__()

        self.convs = nn.ModuleList(
            WeightedSAGEConv(hidden_dims, hidden_dims, hidden_dims)
            for _ in range(n_layers)
        )

    def forward(self, blocks, h):
        for conv, block in zip(self.convs, blocks):
            # The first num_dst rows of h correspond to the block's
            # destination nodes.
            num_dst = block.num_nodes("DST/" + block.ntypes[0])
            h = conv(block, (h, h[:num_dst]), block.edata["weights"])
        return h

class ItemToItemScorer(nn.Module):
    """Score item pairs as the dot product of their representations plus a
    learned per-item bias for each endpoint."""

    def __init__(self, full_graph, ntype):
        super().__init__()

        # One scalar bias per item node of the full graph.
        num_items = full_graph.num_nodes(ntype)
        self.bias = nn.Parameter(torch.zeros(num_items, 1))

    def _add_bias(self, edges):
        # Look up each endpoint's bias via its ID in the full graph.
        src_bias = self.bias[edges.src[dgl.NID]]
        dst_bias = self.bias[edges.dst[dgl.NID]]
        return {"s": edges.data["s"] + src_bias + dst_bias}

    def forward(self, item_item_graph, h):
        """
        item_item_graph : graph consists of edges connecting the pairs
        h : hidden state of every node
        """
        with item_item_graph.local_scope():
            item_item_graph.ndata["h"] = h
            # Raw score: dot product between the two endpoints' vectors.
            item_item_graph.apply_edges(fn.u_dot_v("h", "h", "s"))
            item_item_graph.apply_edges(self._add_bias)
            pair_score = item_item_graph.edata["s"]
        return pair_score