import dgl
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F


def disable_grad(module):
    """Freeze all parameters of ``module`` so they are excluded from training."""
    for param in module.parameters():
        param.requires_grad = False


def _init_input_modules(g, ntype, textset, hidden_dims):
    # We initialize the linear projections of each input feature ``x`` as
    # follows:
    # * If ``x`` is a scalar integer feature (a 1-D int64 tensor), we assume it
    #   is categorical and that its values range over 0..max(x).
    # * If ``x`` is a float feature with one vector per node (a 2-D float32
    #   tensor), we treat it as a numeric feature vector.
    # * If ``x`` is a field of a textset, we process it as bag of words.
    module_dict = nn.ModuleDict()

    for column, data in g.nodes[ntype].data.items():
        if column == dgl.NID:
            continue
        if data.dtype == torch.float32:
            assert data.ndim == 2
            m = nn.Linear(data.shape[1], hidden_dims)
            nn.init.xavier_uniform_(m.weight)
            nn.init.constant_(m.bias, 0)
            module_dict[column] = m
        elif data.dtype == torch.int64:
            assert data.ndim == 1
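            # ``data.max() + 2`` reserves rows 0..max(x) for the category ids
            # plus one trailing row for padding (``padding_idx=-1`` refers to
            # that last row).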
            m = nn.Embedding(data.max() + 2, hidden_dims, padding_idx=-1)
            nn.init.xavier_uniform_(m.weight)
            module_dict[column] = m

    if textset is not None:
        for column, field in textset.items():
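            # Each field is a (textlist, vocab, pad_var, batch_first) tuple;
            # only the vocabulary is needed to build the bag-of-words module.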
            textlist, vocab, pad_var, batch_first = field
            module_dict[column] = BagOfWords(vocab, hidden_dims)

    return module_dict


class BagOfWords(nn.Module):
    def __init__(self, vocab, hidden_dims):
        super().__init__()

        self.emb = nn.Embedding(
            len(vocab.get_itos()),
            hidden_dims,
            padding_idx=vocab.get_stoi()["<pad>"],
        )
        nn.init.xavier_uniform_(self.emb.weight)
        # Re-zero the padding row: ``xavier_uniform_`` overwrites the zeros
        # that ``padding_idx`` set up, and padded positions must not
        # contribute to the sum in ``forward``.
        with torch.no_grad():
            self.emb.weight[self.emb.padding_idx].fill_(0)

    def forward(self, x, length):
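        # ``x``: (batch, max_len) padded token id matrix;
        # ``length``: (batch,) number of real (non-pad) tokens per text.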
        return self.emb(x).sum(1) / length.unsqueeze(1).float()


class LinearProjector(nn.Module):
    """
    Projects each input feature of the graph linearly and sums them up
    """

    def __init__(self, full_graph, ntype, textset, hidden_dims):
        super().__init__()

        self.ntype = ntype
        self.inputs = _init_input_modules(
            full_graph, ntype, textset, hidden_dims
        )

    def forward(self, ndata):
        projections = []
        for feature, data in ndata.items():
            if feature == dgl.NID or feature.endswith("__len"):
                # This is an additional feature indicating the length of the ``feature``
                # column; we shouldn't process this.
                continue

            module = self.inputs[feature]
            if isinstance(module, BagOfWords):
                # Textual feature; find the length and pass it to the textual module.
                length = ndata[feature + "__len"]
                result = module(data, length)
            else:
                result = module(data)
            projections.append(result)

        # Sum the per-feature projections into a single vector per node.
        return torch.stack(projections, 1).sum(1)


class WeightedSAGEConv(nn.Module):
    def __init__(self, input_dims, hidden_dims, output_dims, act=F.relu):
        super().__init__()

        self.act = act
        self.Q = nn.Linear(input_dims, hidden_dims)
        self.W = nn.Linear(input_dims + hidden_dims, output_dims)
        self.dropout = nn.Dropout(0.5)
        self.reset_parameters()

    def reset_parameters(self):
        gain = nn.init.calculate_gain("relu")
        nn.init.xavier_uniform_(self.Q.weight, gain=gain)
        nn.init.xavier_uniform_(self.W.weight, gain=gain)
        nn.init.constant_(self.Q.bias, 0)
        nn.init.constant_(self.W.bias, 0)

    def forward(self, g, h, weights):
        """
        g : graph
        h : node features
        weights : scalar edge weights
        """
        h_src, h_dst = h
        with g.local_scope():
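            # Weighted mean over neighbors: transform the source features,
            # scale messages by the edge weights, then divide by the total
            # weight per destination node (clamped to avoid division by zero).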
            g.srcdata["n"] = self.act(self.Q(self.dropout(h_src)))
            g.edata["w"] = weights.float()
            g.update_all(fn.u_mul_e("n", "w", "m"), fn.sum("m", "n"))
            g.update_all(fn.copy_e("w", "m"), fn.sum("m", "ws"))
            n = g.dstdata["n"]
            ws = g.dstdata["ws"].unsqueeze(1).clamp(min=1)
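            # Concatenate the aggregated neighborhood with the destination
            # features, then L2-normalize the output as PinSAGE does.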
            z = self.act(self.W(self.dropout(torch.cat([n / ws, h_dst], 1))))
            z_norm = z.norm(2, 1, keepdim=True)
            z_norm = torch.where(
                z_norm == 0, torch.tensor(1.0).to(z_norm), z_norm
            )
            z = z / z_norm
            return z


class SAGENet(nn.Module):
    def __init__(self, hidden_dims, n_layers):
        """
peizhou001's avatar
peizhou001 committed
138
        g : DGLGraph
139
140
141
142
143
144
145
146
147
            The user-item interaction graph.
            This is only for finding the range of categorical variables.
        item_textsets : torchtext.data.Dataset
            The textual features of each item node.
        """
        super().__init__()

        self.convs = nn.ModuleList()
        for _ in range(n_layers):
            self.convs.append(
                WeightedSAGEConv(hidden_dims, hidden_dims, hidden_dims)
            )

    def forward(self, blocks, h):
        for layer, block in zip(self.convs, blocks):
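            # The first ``num_dst`` rows of ``h`` are the destination nodes'
            # own features (DGL blocks list destination nodes first among
            # the source nodes).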
            h_dst = h[: block.num_nodes("DST/" + block.ntypes[0])]
            h = layer(block, (h, h_dst), block.edata["weights"])
        return h


class ItemToItemScorer(nn.Module):
    def __init__(self, full_graph, ntype):
        super().__init__()

        n_nodes = full_graph.num_nodes(ntype)
        self.bias = nn.Parameter(torch.zeros(n_nodes, 1))

    def _add_bias(self, edges):
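        # Add a learnable per-item bias to each pair score, looking the
        # items up by their IDs in the full graph (``dgl.NID``).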
        bias_src = self.bias[edges.src[dgl.NID]]
        bias_dst = self.bias[edges.dst[dgl.NID]]
        return {"s": edges.data["s"] + bias_src + bias_dst}

    def forward(self, item_item_graph, h):
        """
        item_item_graph : graph consists of edges connecting the pairs
        h : hidden state of every node
        """
        with item_item_graph.local_scope():
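            # Pair score = dot product of the two items' representations,
            # plus a learnable per-item bias for each endpoint.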
            item_item_graph.ndata["h"] = h
            item_item_graph.apply_edges(fn.u_dot_v("h", "h", "s"))
            item_item_graph.apply_edges(self._add_bias)
            pair_score = item_item_graph.edata["s"]
        return pair_score
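

if __name__ == "__main__":
    # Minimal smoke test: a sketch rather than part of the original example.
    # It runs one convolution and scores a pair batch on a tiny random item
    # graph; the graph, sizes, and pair choices below are illustrative
    # assumptions, not the PinSAGE training setup.
    num_items, hidden_dims = 10, 16
    src = torch.randint(0, num_items, (30,))
    dst = torch.randint(0, num_items, (30,))
    g = dgl.graph((src, dst), num_nodes=num_items)

    # One block (message flow graph) covering every item as a destination.
    block = dgl.to_block(g, torch.arange(num_items))
    block.edata["weights"] = torch.ones(block.num_edges())

    net = SAGENet(hidden_dims, n_layers=1)
    h = torch.randn(block.num_src_nodes(), hidden_dims)
    h_out = net([block], h)  # (num_items, hidden_dims)

    # Score two item pairs; the scorer looks nodes up by ``dgl.NID``.
    pair_g = dgl.graph(
        (torch.tensor([0, 1]), torch.tensor([2, 3])), num_nodes=num_items
    )
    pair_g.ndata[dgl.NID] = torch.arange(num_items)
    scorer = ItemToItemScorer(g, g.ntypes[0])
    print(scorer(pair_g, h_out).shape)  # torch.Size([2, 1])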