sampler.py
import dgl
import numpy as np
import torch as th


class Sampler:
    def __init__(self,
                 graph,
                 walk_length,
                 num_walks,
                 window_size,
                 num_negative):
        self.graph = graph
        self.walk_length = walk_length
        self.num_walks = num_walks
        self.window_size = window_size
        self.num_negative = num_negative
        self.node_weights = self.compute_node_sample_weight()

    def sample(self, batch, sku_info):
        """
            Given a batch of target nodes, sample positive
            pairs and negative pairs from the graph.
        """
        # start `self.num_walks` random walks from every target node
        batch = np.repeat(batch, self.num_walks)

        pos_pairs = self.generate_pos_pairs(batch)
        neg_pairs = self.generate_neg_pairs(pos_pairs)

        # look up sku side information by node id
        srcs, dsts, labels = [], [], []
        for pair in pos_pairs + neg_pairs:
            src, dst, label = pair
            src_info = sku_info[src]
            dst_info = sku_info[dst]

            srcs.append(src_info)
            dsts.append(dst_info)
            labels.append(label)

        return th.tensor(srcs), th.tensor(dsts), th.tensor(labels)

    def filter_padding(self, traces):
        # dgl pads early-terminated walks with -1; drop those entries
        for i in range(len(traces)):
            traces[i] = [x for x in traces[i] if x != -1]

    def generate_pos_pairs(self, nodes):
        """
            For the walk [1, 2, 3, 4] and center node 2,
            window_size=1 generates the pairs:
                (2, 1) and (2, 3)
        """
        # weighted random walk; each trace holds walk_length + 1 nodes
        traces, _ = dgl.sampling.random_walk(
                g=self.graph,
                nodes=nodes,
                length=self.walk_length,
                prob="weight"
            )
        traces = traces.tolist()
        self.filter_padding(traces)

        # skip-gram
        pairs = []
        for trace in traces:
            for i in range(len(trace)):
                center = trace[i]
                left = max(0, i - self.window_size)
                right = min(len(trace), i + self.window_size + 1)
                pairs.extend([[center, x, 1] for x in trace[left:i]])
                pairs.extend([[center, x, 1] for x in trace[i+1:right]])
        
        return pairs

    def compute_node_sample_weight(self):
        """
            Use node in-degree as the negative-sampling weight.
        """
        return self.graph.in_degrees().float()

    def generate_neg_pairs(self, pos_pairs):
        """
            Sample negatives in proportion to node degree, so
            high-degree (frequently visited) nodes are more likely
            to be drawn as negative nodes.
        """
        # sample `self.num_negative` negative dst nodes
        # for each positive pair's src node
        negs = th.multinomial(
                self.node_weights,
                len(pos_pairs) * self.num_negative,
                replacement=True
            ).tolist()
        
        tar = np.repeat([pair[0] for pair in pos_pairs], self.num_negative)
        assert len(tar) == len(negs)
        neg_pairs = [[x, y, 0] for x, y in zip(tar, negs)]

        return neg_pairs
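

# Minimal usage sketch (an assumption, not part of the original file):
# it builds a toy weighted item graph, attaches illustrative sku side
# information, and draws one batch of (src, dst, label) training triples.
# The graph construction, hyperparameters, and `sku_info` layout below
# are only examples of what the sampler expects.
if __name__ == "__main__":
    # toy graph: 4 items connected in a cycle, edges in both directions
    src = th.tensor([0, 1, 2, 3, 1, 2, 3, 0])
    dst = th.tensor([1, 2, 3, 0, 0, 1, 2, 3])
    g = dgl.graph((src, dst))
    g.edata["weight"] = th.ones(g.num_edges())  # uniform walk probabilities

    # per-item side information, e.g. [sku_id, category_id, brand_id]
    sku_info = {i: [i, i % 2, i % 3] for i in range(g.num_nodes())}

    sampler = Sampler(
        graph=g,
        walk_length=3,
        num_walks=2,
        window_size=1,
        num_negative=2,
    )
    srcs, dsts, labels = sampler.sample(np.array([0, 1]), sku_info)
    print(srcs.shape, dsts.shape, labels.shape)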