# sampler.py -- PinSAGE sampling and collation utilities.
# Last committed by Hongzhi (Steve) Chen.
1
import dgl
2
3
import numpy as np
import torch
4
from torch.utils.data import DataLoader, IterableDataset
5
6
from torchtext.data.functional import numericalize_tokens_from_iterator

7

8
9
10
11
12
13
14
15
16
17
18
def padding(array, yy, val):
    """
    Pad a tensor on the trailing side up to a target width.

    The amount of padding is ``yy - array.shape[0]``, applied to the end of
    the last dimension (so this is meant for 1-D tensors).

    :param array: torch tensor to pad
    :param yy: desired width
    :param val: value written into the padded positions
    :return: padded tensor
    """
    # Nothing is added in front; everything that is missing goes at the end.
    missing = yy - array.shape[0]
    return torch.nn.functional.pad(
        array, pad=(0, missing), mode="constant", value=val
    )

23
24
25
26

def compact_and_copy(frontier, seeds):
    """
    Turn a sampled frontier into a block and carry its edge features over.

    ``dgl.to_block`` stores, under ``dgl.EID``, the frontier edge ID of every
    block edge; that mapping is used to gather each other edge feature.

    :param frontier: sampled graph passed to ``dgl.to_block``
    :param seeds: node IDs handed to ``dgl.to_block`` as destination nodes
    :return: the block, with the frontier's edge features attached
    """
    block = dgl.to_block(frontier, seeds)
    # The loop never writes dgl.EID, so the mapping can be read once.
    origin_eids = block.edata[dgl.EID]
    for name, feat in frontier.edata.items():
        if name != dgl.EID:
            block.edata[name] = feat[origin_eids]
    return block

32

33
34
35
36
37
class ItemToItemBatchSampler(IterableDataset):
    """
    Endless stream of ``(heads, tails, neg_tails)`` item-ID batches.

    Each batch starts from ``batch_size`` uniformly drawn items. The positive
    tail is the node reached after an item -> user -> item random walk. The
    negative tail is another uniformly drawn item.
    """

    def __init__(self, g, user_type, item_type, batch_size):
        self.g = g
        self.user_type = user_type
        self.item_type = item_type
        # First edge type found in each direction between the two node types.
        meta = g.metagraph()
        self.user_to_item_etype = list(meta[user_type][item_type])[0]
        self.item_to_user_etype = list(meta[item_type][user_type])[0]
        self.batch_size = batch_size

    def __iter__(self):
        walk_metapath = [self.item_to_user_etype, self.user_to_item_etype]
        while True:
            num_items = self.g.num_nodes(self.item_type)
            heads = torch.randint(0, num_items, (self.batch_size,))
            traces = dgl.sampling.random_walk(
                self.g, heads, metapath=walk_metapath
            )[0]
            # Column 2 is the node reached after both hops of the metapath.
            tails = traces[:, 2]
            neg_tails = torch.randint(0, num_items, (self.batch_size,))

            # Discard rows whose walk yielded -1 at the final hop.
            keep = tails != -1
            yield heads[keep], tails[keep], neg_tails[keep]

59

60
class NeighborSampler:
    """
    Samples multi-layer PinSAGE neighbourhoods (blocks) over item nodes.

    One ``dgl.sampling.PinSAGESampler`` is built per layer, all sharing the
    same random-walk settings; ``sample_blocks`` applies them layer by layer.
    """

    def __init__(
        self,
        g,
        user_type,
        item_type,
        random_walk_length,
        random_walk_restart_prob,
        num_random_walks,
        num_neighbors,
        num_layers,
    ):
        """
        :param g: heterogeneous graph containing ``user_type`` and ``item_type``
        :param user_type: node type name of users
        :param item_type: node type name of items (the nodes that are sampled)
        :param random_walk_length: forwarded to ``PinSAGESampler``
        :param random_walk_restart_prob: forwarded to ``PinSAGESampler``
        :param num_random_walks: forwarded to ``PinSAGESampler``
        :param num_neighbors: forwarded to ``PinSAGESampler``
        :param num_layers: number of sampling layers (one sampler per layer)
        """
        self.g = g
        self.user_type = user_type
        self.item_type = item_type
        # First edge type found in each direction between the two node types.
        self.user_to_item_etype = list(g.metagraph()[user_type][item_type])[0]
        self.item_to_user_etype = list(g.metagraph()[item_type][user_type])[0]
        self.samplers = [
            dgl.sampling.PinSAGESampler(
                g,
                item_type,
                user_type,
                random_walk_length,
                random_walk_restart_prob,
                num_random_walks,
                num_neighbors,
            )
            for _ in range(num_layers)
        ]

    def sample_blocks(self, seeds, heads=None, tails=None, neg_tails=None):
        """
        Sample one block per layer, starting from ``seeds``.

        :param seeds: item IDs needed at the output layer
        :param heads: optional head item IDs of training pairs; when given,
            ``tails`` and ``neg_tails`` must be given too
        :param tails: positive tail item IDs aligned with ``heads``
        :param neg_tails: negative tail item IDs aligned with ``heads``
        :return: list of blocks; the last sampled layer is ``blocks[0]``
            (closest to the input) and the seed layer is ``blocks[-1]``
        """
        blocks = []
        for sampler in self.samplers:
            frontier = sampler(seeds)
            if heads is not None:
                # Remove every frontier edge head->tail and head->neg_tail,
                # presumably so the pairs being scored are not also used for
                # message passing.
                eids = frontier.edge_ids(
                    torch.cat([heads, heads]),
                    torch.cat([tails, neg_tails]),
                    return_uv=True,
                )[2]
                if len(eids) > 0:
                    frontier = dgl.remove_edges(frontier, eids)
            block = compact_and_copy(frontier, seeds)
            # Source nodes of this block are the seeds of the next layer.
            seeds = block.srcdata[dgl.NID]
            blocks.insert(0, block)
        return blocks

    def sample_from_item_pairs(self, heads, tails, neg_tails):
        """
        Build positive/negative pair graphs plus the blocks needed to embed them.

        :param heads: head item IDs
        :param tails: positive tail item IDs, aligned with ``heads``
        :param neg_tails: negative tail item IDs, aligned with ``heads``
        :return: ``(pos_graph, neg_graph, blocks)``; the pair graphs are
            compacted together by ``dgl.compact_graphs``, and the original
            item IDs of ``pos_graph`` are the seeds of ``blocks``
        """
        # Create a graph with positive connections only and another graph
        # with negative connections only.
        num_items = self.g.num_nodes(self.item_type)
        pos_graph = dgl.graph((heads, tails), num_nodes=num_items)
        neg_graph = dgl.graph((heads, neg_tails), num_nodes=num_items)
        pos_graph, neg_graph = dgl.compact_graphs([pos_graph, neg_graph])
        seeds = pos_graph.ndata[dgl.NID]

        blocks = self.sample_blocks(seeds, heads, tails, neg_tails)
        return pos_graph, neg_graph, blocks

127

128
129
130
131
132
133
134
135
136
137
def assign_simple_node_features(ndata, g, ntype, assign_id=False):
    """
    Copies data to the given block from the corresponding nodes in the original graph.

    Every node feature column of ``g.nodes[ntype]`` is gathered with the IDs
    in ``ndata[dgl.NID]`` and stored in ``ndata`` under the same name.

    :param ndata: node-data view of a block (e.g. ``block.srcdata``); must
        contain ``dgl.NID`` whenever there is at least one column to copy
    :param g: original graph holding the features
    :param ntype: node type whose features are copied
    :param assign_id: if True, a ``dgl.NID`` column stored in ``g`` is copied
        as well (overwriting ``ndata[dgl.NID]``); otherwise it is skipped
    """
    to_copy = [
        (col, feat)
        for col, feat in g.nodes[ntype].data.items()
        if assign_id or col != dgl.NID
    ]
    if not to_copy:
        return
    # Read the mapping once, before anything is written: with assign_id=True
    # the loop may overwrite ndata[dgl.NID], and the remaining columns must
    # still be gathered with the block's original IDs.
    induced_nodes = ndata[dgl.NID]
    for col, feat in to_copy:
        ndata[col] = feat[induced_nodes]

138

139
140
141
142
143
144
145
146
147
148
def assign_textual_node_features(ndata, textset, ntype):
    """
    Assigns numericalized tokens from a text field mapping to given block.

    The numericalized tokens would be stored in the block as node features
    with the same name as ``field_name``.

    The length would be stored as another node feature with name
    ``field_name + '__len'``.

    Parameters
    ----------
    ndata : node-data view of a block (e.g. ``blocks[0].srcdata``)
        Must contain ``dgl.NID``, the corresponding node ID in the original
        graph, hence the index into each text list.

        The numericalized tokens and lengths are stored onto it as new node
        features. Tokens have shape ``(num_nodes, maxsize)``, or its
        transpose when ``batch_first`` is False. Both tokens and lengths have
        dtype int64.
    textset : dict
        Maps ``field_name`` to ``(textlist, vocab, pad_var, batch_first)``,
        where ``textlist[i]`` is the token sequence of node ``i`` in the
        original graph (presumably one entry per node -- verify at caller).
    ntype : str
        Unused; kept for interface compatibility.
    """
    node_ids = ndata[dgl.NID].numpy()

    for field_name, (textlist, vocab, pad_var, batch_first) in textset.items():
        examples = [textlist[i] for i in node_ids]
        # dtype is pinned to long: torch.asarray([]) would be float32 for a
        # node with empty text and silently promote every stacked row to float.
        id_seqs = [
            torch.tensor(list(ids), dtype=torch.long)
            for ids in numericalize_tokens_from_iterator(vocab, examples)
        ]
        maxsize = max((len(seq) for seq in id_seqs), default=0)
        lengths = torch.tensor([len(seq) for seq in id_seqs], dtype=torch.long)

        if id_seqs:
            # Stack once rather than vstack-ing inside the loop (quadratic
            # copying). This also keeps the result 2-D for a single node.
            tokens = torch.stack(
                [padding(seq, maxsize, pad_var) for seq in id_seqs]
            )
        else:
            tokens = torch.empty((0, 0), dtype=torch.long)

        if not batch_first:
            tokens = tokens.t()

        ndata[field_name] = tokens
        ndata[field_name + "__len"] = lengths

187
188
189
190
191
192
193
194
195

def assign_features_to_blocks(blocks, g, textset, ntype):
    """
    Attach node features to both outer ends of a list of blocks.

    The source nodes of the first block (closest to the input) and the
    destination nodes of the last block each receive the simple features
    copied from ``g`` and the numericalized text from ``textset``.
    """
    for ndata in (blocks[0].srcdata, blocks[-1].dstdata):
        assign_simple_node_features(ndata, g, ntype)
        assign_textual_node_features(ndata, textset, ntype)

196

197
198
199
200
201
202
203
204
205
206
class PinSAGECollator:
    """
    Collate functions turning sampled item IDs into PinSAGE model inputs.

    Both entry points sample blocks with ``sampler`` and then attach node
    features from ``g`` / ``textset`` for node type ``ntype``.
    """

    def __init__(self, sampler, g, ntype, textset):
        self.sampler = sampler
        self.ntype = ntype
        self.g = g
        self.textset = textset

    def collate_train(self, batches):
        """
        Collate a training batch.

        :param batches: sequence whose first element is a
            ``(heads, tails, neg_tails)`` triple; other elements are ignored
        :return: ``(pos_graph, neg_graph, blocks)`` with features on ``blocks``
        """
        first = batches[0]
        heads, tails, neg_tails = first
        # Construct multilayer neighborhood via PinSAGE...
        sampled = self.sampler.sample_from_item_pairs(heads, tails, neg_tails)
        pos_graph, neg_graph, blocks = sampled
        assign_features_to_blocks(blocks, self.g, self.textset, self.ntype)
        return pos_graph, neg_graph, blocks

    def collate_test(self, samples):
        """
        Collate item IDs for inference.

        :param samples: item IDs, converted with ``torch.LongTensor``
        :return: blocks with features attached
        """
        seeds = torch.LongTensor(samples)
        blocks = self.sampler.sample_blocks(seeds)
        assign_features_to_blocks(blocks, self.g, self.textset, self.ntype)
        return blocks