Unverified commit 11adb4e7 authored by yxy235, committed by GitHub

[GraphBolt] Refactor `NegativeSampler`. (#7001)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 366fb02f
@@ -935,6 +935,64 @@ class FusedCSCSamplingGraph(SamplingGraph):
),
)
def sample_negative_edges_uniform_2(
self, edge_type, node_pairs, negative_ratio
):
"""
Sample negative edges by randomly choosing negative source-destination
edges according to a uniform distribution. For each edge ``(u, v)``,
it is supposed to generate `negative_ratio` pairs of negative edges
``(u, v')``, where ``v'`` is chosen uniformly from all the nodes in
the graph. ``u`` is exactly same as the corresponding positive edges.
It returns positive edges concatenated with negative edges. In
negative edges, negative sources are constructed from the
corresponding positive edges.
Parameters
----------
edge_type: str
The type of edges in the provided node_pairs. Any negative edges
sampled will also have the same type. If set to None, it will be
considered as a homogeneous graph.
node_pairs : torch.Tensor
A 2D tensors that represent the N pairs of positive edges in
source-destination format, with 'positive' indicating that these
edges are present in the graph. It's important to note that within
the context of a heterogeneous graph, the ids in these tensors
signify heterogeneous ids.
negative_ratio: int
The ratio of the number of negative samples to positive samples.
Returns
-------
torch.Tensor
A 2D tensors represents the N pairs of positive and negative
source-destination node pairs. In the context of a heterogeneous
graph, both the input nodes and the selected nodes are represented
by heterogeneous IDs, and the formed edges are of the input type
`edge_type`. Note that negative refers to false negatives, which
means the edge could be present or not present in the graph.
"""
if edge_type:
_, _, dst_ntype = etype_str_to_tuple(edge_type)
max_node_id = self.num_nodes[dst_ntype]
else:
max_node_id = self.total_num_nodes
pos_src = node_pairs[:, 0]
num_negative = node_pairs.shape[0] * negative_ratio
negative_seeds = (
torch.cat(
(
pos_src.repeat_interleave(negative_ratio),
torch.randint(0, max_node_id, (num_negative,)),
),
)
.view(2, num_negative)
.T
)
seeds = torch.cat((node_pairs, negative_seeds))
return seeds
def copy_to_shared_memory(self, shared_memory_name: str):
"""Copy the graph to shared memory.
...
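To make the shape arithmetic of `sample_negative_edges_uniform_2` concrete, here is a minimal standalone sketch of the same construction, using made-up node IDs and a stand-in `max_node_id` instead of a real `FusedCSCSamplingGraph`:

import torch

# Hypothetical inputs: N = 3 positive pairs, negative_ratio = 2, and 10 nodes
# standing in for total_num_nodes / num_nodes[dst_ntype].
node_pairs = torch.tensor([[0, 1], [2, 3], [4, 5]])
negative_ratio = 2
max_node_id = 10

pos_src = node_pairs[:, 0]
num_negative = node_pairs.shape[0] * negative_ratio
# Repeat each positive source `negative_ratio` times and pair it with a
# uniformly sampled destination, mirroring the method above.
negative_seeds = (
    torch.cat(
        (
            pos_src.repeat_interleave(negative_ratio),
            torch.randint(0, max_node_id, (num_negative,)),
        )
    )
    .view(2, num_negative)
    .T
)
seeds = torch.cat((node_pairs, negative_seeds))
print(seeds.shape)  # torch.Size([9, 2]): 3 positives followed by 3 * 2 negatives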
"""Uniform negative sampler for GraphBolt.""" """Uniform negative sampler for GraphBolt."""
import torch
from torch.utils.data import functional_datapipe from torch.utils.data import functional_datapipe
from ..negative_sampler import NegativeSampler from ..negative_sampler import NegativeSampler
...@@ -61,11 +62,33 @@ class UniformNegativeSampler(NegativeSampler): ...@@ -61,11 +62,33 @@ class UniformNegativeSampler(NegativeSampler):
self.graph = graph self.graph = graph
def _sample_with_etype(self, node_pairs, etype=None, use_seeds=False):
-if not use_seeds:
-return self.graph.sample_negative_edges_uniform(
if use_seeds:
assert node_pairs.ndim == 2 and node_pairs.shape[1] == 2, (
"Only tensor with shape N*2 is supported for negative"
+ f" sampling, but got {node_pairs.shape}."
)
# Sample negative edges, and concatenate positive edges with them.
seeds = self.graph.sample_negative_edges_uniform_2(
etype,
node_pairs,
self.negative_ratio,
)
# Construct indexes for all node pairs.
num_pos_node_pairs = node_pairs.shape[0]
negative_ratio = self.negative_ratio
pos_indexes = torch.arange(0, num_pos_node_pairs)
neg_indexes = pos_indexes.repeat_interleave(negative_ratio)
indexes = torch.cat((pos_indexes, neg_indexes))
# Construct labels for all node pairs.
pos_num = node_pairs.shape[0]
neg_num = seeds.shape[0] - pos_num
labels = torch.cat(
(torch.ones(pos_num), torch.zeros(neg_num))
).bool()
return seeds, labels, indexes
else:
-raise NotImplementedError("Not implemented yet.")
return self.graph.sample_negative_edges_uniform(
etype,
node_pairs,
self.negative_ratio,
)
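For reference, the `labels`/`indexes` layout built in the `use_seeds` branch can be traced with a toy batch; the values below simply come from running the same tensor operations on assumed sizes (N = 2 positives, `negative_ratio` = 3):

import torch

num_pos, negative_ratio = 2, 3  # assumed toy sizes
pos_indexes = torch.arange(0, num_pos)                       # tensor([0, 1])
neg_indexes = pos_indexes.repeat_interleave(negative_ratio)  # tensor([0, 0, 0, 1, 1, 1])
indexes = torch.cat((pos_indexes, neg_indexes))
labels = torch.cat(
    (torch.ones(num_pos), torch.zeros(num_pos * negative_ratio))
).bool()
# indexes -> tensor([0, 1, 0, 0, 0, 1, 1, 1]); each entry maps a row of
# `seeds` back to the positive pair (query) it was generated for.
# labels  -> the first num_pos entries are True (positives), the rest False.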
@@ -38,7 +38,9 @@ class NegativeSampler(MiniBatchTransformer):
def _sample(self, minibatch):
"""
Generate a mix of positive and negative samples. If `seeds` in the
minibatch is not None, `labels` and `indexes` will be constructed
after negative sampling, based on the corresponding seeds.
Parameters
----------
@@ -69,12 +71,31 @@ class NegativeSampler(MiniBatchTransformer):
else:
self._collate(minibatch, self._sample_with_etype(node_pairs))
else:
-raise NotImplementedError("Not implemented yet.")
seeds = minibatch.seeds
if isinstance(seeds, Mapping):
if minibatch.indexes is None:
minibatch.indexes = {}
if minibatch.labels is None:
minibatch.labels = {}
for etype, pos_pairs in seeds.items():
(
minibatch.seeds[etype],
minibatch.labels[etype],
minibatch.indexes[etype],
) = self._sample_with_etype(pos_pairs, use_seeds=True)
else:
(
minibatch.seeds,
minibatch.labels,
minibatch.indexes,
) = self._sample_with_etype(seeds, use_seeds=True)
return minibatch
def _sample_with_etype(self, node_pairs, etype=None, use_seeds=False):
"""Generate negative pairs from positive pairs for a given etype.
If `node_pairs` is a 2D tensor (i.e., `seeds` is used in the
minibatch), the corresponding labels and indexes will also be
constructed.
Parameters
----------
@@ -87,8 +108,14 @@ class NegativeSampler(MiniBatchTransformer):
Returns
-------
Tuple[Tensor, Tensor] or Tensor
A collection of negative node pairs.
Tensor or None
Corresponding labels. If a label is True, the corresponding edge is
positive; if it is False, the corresponding edge is negative.
Tensor or None
Corresponding indexes, indicating to which query an edge belongs.
"""
raise NotImplementedError
@@ -98,8 +125,8 @@ class NegativeSampler(MiniBatchTransformer):
Parameters
----------
minibatch : MiniBatch
The input minibatch, which contains positive node pairs, will be
filled with negative information in this function.
neg_pairs : Tuple[Tensor, Tensor]
A tuple of tensors representing source-destination node pairs of
negative edges, where negative means the edge may not exist in
@@ -107,7 +134,6 @@ class NegativeSampler(MiniBatchTransformer):
etype : str
Canonical edge type.
"""
-if minibatch.seeds is None:
neg_src, neg_dst = neg_pairs
if neg_src is not None:
neg_src = neg_src.view(-1, self.negative_ratio)
@@ -119,5 +145,3 @@ class NegativeSampler(MiniBatchTransformer):
else:
minibatch.negative_srcs = neg_src
minibatch.negative_dsts = neg_dst
-else:
-raise NotImplementedError("Not implemented yet.")
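In the pre-existing node-pairs path above, `_collate` reshapes the flat negative tensors with `view(-1, self.negative_ratio)` so that row i holds the negatives drawn for positive pair i; a small sketch with assumed sizes:

import torch

batch_size, negative_ratio = 2, 3  # assumed toy sizes
# A flat tensor of batch_size * negative_ratio negative destinations.
neg_dst = torch.arange(batch_size * negative_ratio)  # tensor([0, 1, 2, 3, 4, 5])
print(neg_dst.view(-1, negative_ratio))
# tensor([[0, 1, 2],
#         [3, 4, 5]])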
import re
import dgl.graphbolt as gb
import pytest
import torch
@@ -31,7 +33,7 @@ def test_NegativeSampler_invoke():
next(iter(negative_sampler))
-def test_UniformNegativeSampler_seeds_invoke():
def test_UniformNegativeSampler_invoke():
# Instantiate graph and required datapipes.
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True)
num_seeds = 30
@@ -41,24 +43,32 @@ def test_UniformNegativeSampler_seeds_invoke():
batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
negative_ratio = 2
def _verify(negative_sampler):
for data in negative_sampler:
# Assertion
seeds_len = batch_size + batch_size * negative_ratio
assert data.seeds.size(0) == seeds_len
assert data.labels.size(0) == seeds_len
assert data.indexes.size(0) == seeds_len
# Invoke UniformNegativeSampler via class constructor.
negative_sampler = gb.UniformNegativeSampler(
item_sampler,
graph,
negative_ratio,
)
-with pytest.raises(NotImplementedError):
-next(iter(negative_sampler))
_verify(negative_sampler)
# Invoke UniformNegativeSampler via functional form.
negative_sampler = item_sampler.sample_uniform_negative(
graph,
negative_ratio,
)
-with pytest.raises(NotImplementedError):
-next(iter(negative_sampler))
_verify(negative_sampler)
-def test_UniformNegativeSampler_invoke():
def test_UniformNegativeSampler_node_pairs_invoke():
# Instantiate graph and required datapipes.
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True)
num_seeds = 30
@@ -94,7 +104,7 @@ def test_UniformNegativeSampler_invoke():
@pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20])
-def test_Uniform_NegativeSampler(negative_ratio):
def test_Uniform_NegativeSampler_node_pairs(negative_ratio):
# Construct FusedCSCSamplingGraph.
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True)
num_seeds = 30
@@ -121,6 +131,112 @@ def test_Uniform_NegativeSampler(negative_ratio):
assert neg_dst.numel() == batch_size * negative_ratio
@pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20])
def test_Uniform_NegativeSampler(negative_ratio):
# Construct FusedCSCSamplingGraph.
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True)
num_seeds = 30
item_set = gb.ItemSet(
torch.arange(0, num_seeds * 2).reshape(-1, 2), names="seeds"
)
batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
# Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler(
item_sampler,
graph,
negative_ratio,
)
# Perform Negative sampling.
for data in negative_sampler:
seeds_len = batch_size + batch_size * negative_ratio
# Assertion
assert data.seeds.size(0) == seeds_len
assert data.labels.size(0) == seeds_len
assert data.indexes.size(0) == seeds_len
# Check negative seeds value.
pos_src = data.seeds[:batch_size, 0]
neg_src = data.seeds[batch_size:, 0]
assert torch.equal(pos_src.repeat_interleave(negative_ratio), neg_src)
# Check labels.
assert torch.equal(data.labels[:batch_size], torch.ones(batch_size))
assert torch.equal(
data.labels[batch_size:], torch.zeros(batch_size * negative_ratio)
)
# Check indexes.
pos_indexes = torch.arange(0, batch_size)
neg_indexes = pos_indexes.repeat_interleave(negative_ratio)
expected_indexes = torch.cat((pos_indexes, neg_indexes))
assert torch.equal(data.indexes, expected_indexes)
def test_Uniform_NegativeSampler_error_shape():
# 1. seeds with shape N*3.
# Construct FusedCSCSamplingGraph.
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True)
num_seeds = 30
item_set = gb.ItemSet(
torch.arange(0, num_seeds * 3).reshape(-1, 3), names="seeds"
)
batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
negative_ratio = 2
# Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler(
item_sampler,
graph,
negative_ratio,
)
with pytest.raises(
AssertionError,
match=re.escape(
"Only tensor with shape N*2 is "
+ "supported for negative sampling, but got torch.Size([10, 3])."
),
):
next(iter(negative_sampler))
# 2. seeds with shape N*2*1.
# Construct ItemSet and ItemSampler.
item_set = gb.ItemSet(
torch.arange(0, num_seeds * 2).reshape(-1, 2, 1), names="seeds"
)
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
# Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler(
item_sampler,
graph,
negative_ratio,
)
with pytest.raises(
AssertionError,
match=re.escape(
"Only tensor with shape N*2 is "
+ "supported for negative sampling, but got torch.Size([10, 2, 1])."
),
):
next(iter(negative_sampler))
# 3. seeds with shape N.
# Construct ItemSet and ItemSampler.
item_set = gb.ItemSet(torch.arange(0, num_seeds), names="seeds")
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
# Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler(
item_sampler,
graph,
negative_ratio,
)
with pytest.raises(
AssertionError,
match=re.escape(
"Only tensor with shape N*2 is "
+ "supported for negative sampling, but got torch.Size([10])."
),
):
next(iter(negative_sampler))
def get_hetero_graph():
# COO graph:
# [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
@@ -143,7 +259,7 @@ def get_hetero_graph():
)
-def test_NegativeSampler_Hetero_Data():
def test_NegativeSampler_Hetero_node_pairs_Data():
graph = get_hetero_graph()
itemset = gb.ItemSetDict(
{
@@ -161,3 +277,23 @@ def test_NegativeSampler_Hetero_Data():
item_sampler = gb.ItemSampler(itemset, batch_size=2)
negative_dp = gb.UniformNegativeSampler(item_sampler, graph, 1)
assert len(list(negative_dp)) == 5
def test_NegativeSampler_Hetero_Data():
graph = get_hetero_graph()
itemset = gb.ItemSetDict(
{
"n1:e1:n2": gb.ItemSet(
torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T,
names="seeds",
),
"n2:e2:n1": gb.ItemSet(
torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T,
names="seeds",
),
}
)
item_sampler = gb.ItemSampler(itemset, batch_size=2)
negative_dp = gb.UniformNegativeSampler(item_sampler, graph, 1)
assert len(list(negative_dp)) == 5
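A possible follow-up check, not part of this commit: assuming the per-etype `seeds`/`labels`/`indexes` dicts follow the same contract as the homogeneous path, the hetero test above could be extended along these lines (the test name `test_NegativeSampler_Hetero_Data_shapes` is hypothetical):

def test_NegativeSampler_Hetero_Data_shapes():
    graph = get_hetero_graph()
    itemset = gb.ItemSetDict(
        {
            "n1:e1:n2": gb.ItemSet(
                torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T,
                names="seeds",
            ),
            "n2:e2:n1": gb.ItemSet(
                torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T,
                names="seeds",
            ),
        }
    )
    negative_ratio = 1
    item_sampler = gb.ItemSampler(itemset, batch_size=2)
    negative_dp = gb.UniformNegativeSampler(item_sampler, graph, negative_ratio)
    for data in negative_dp:
        for etype, seeds in data.seeds.items():
            # Number of positive pairs for this etype in the batch.
            num_pos = int(data.labels[etype].sum())
            # Each positive pair contributes `negative_ratio` negatives.
            assert seeds.size(0) == num_pos * (1 + negative_ratio)
            assert data.labels[etype].size(0) == seeds.size(0)
            assert data.indexes[etype].size(0) == seeds.size(0)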