import dgl
import torch as th


class NegativeSampler(object):
    """Draw negative destination nodes for a batch of edges.

    Destination nodes are sampled with replacement, with probability
    proportional to ``in_degree ** 0.75`` of the graph given at construction.

    Parameters
    ----------
    g
        Graph whose ``in_degrees()`` define the sampling weights.
    k
        Number of negative destinations generated per input edge.
    neg_share
        If true, and the batch size is a multiple of ``k``, every run of
        ``k`` consecutive edges shares one common set of ``k`` negatives
        instead of each edge getting independently drawn ones.
    device
        Device on which the sampling weights are kept; defaults to
        ``g.device``.
    """

    def __init__(self, g, k, neg_share=False, device=None):
        if device is None:
            device = g.device
        # Unnormalised weights; ``multinomial`` normalises them when sampling.
        self.weights = g.in_degrees().float().to(device) ** 0.75
        self.k = k
        self.neg_share = neg_share

    def __call__(self, g, eids):
        """Return ``(src, dst)`` negative pairs for the edges ``eids`` of ``g``.

        ``src`` holds the source node of every edge repeated ``k`` times in
        place; ``dst`` holds the sampled negative destinations and lives on
        the device of the sampling weights. Both have ``len(eids) * k``
        entries. An empty ``eids`` yields two empty tensors.
        """
        src, _ = g.find_edges(eids)
        n = len(src)
        if n == 0:
            # ``multinomial`` refuses to draw zero samples, and the shared
            # branch below cannot reshape an empty tensor, so an empty batch
            # is answered directly with an empty index tensor.
            dst = th.empty(0, dtype=th.long, device=self.weights.device)
        elif self.neg_share and n % self.k == 0:
            # Draw only n negatives, split them into n / k groups of k, and
            # repeat each group k times so that the k edges of a group are
            # all paired with the same k negatives (n * k entries in total).
            dst = self.weights.multinomial(n, replacement=True)
            dst = dst.view(-1, 1, self.k).expand(-1, self.k, -1).flatten()
        else:
            # Independent negatives: k fresh draws for every edge.
            dst = self.weights.multinomial(n * self.k, replacement=True)
        # Align sources with the flattened negatives: s0 x k, s1 x k, ...
        src = src.repeat_interleave(self.k)
        return src, dst