Unverified commit af0b63ed, authored by yxy235 and committed by GitHub
Browse files

[GraphBolt] Fix gpu `NegativeSampler` for seeds. (#7068)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 3d854a6b
......@@ -1026,7 +1026,13 @@ class FusedCSCSamplingGraph(SamplingGraph):
torch.cat(
(
pos_src.repeat_interleave(negative_ratio),
torch.randint(0, max_node_id, (num_negative,)),
torch.randint(
0,
max_node_id,
(num_negative,),
dtype=node_pairs.dtype,
device=node_pairs.device,
),
),
)
.view(2, num_negative)
......
......@@ -76,15 +76,30 @@ class UniformNegativeSampler(NegativeSampler):
# Construct indexes for all node pairs.
num_pos_node_pairs = node_pairs.shape[0]
negative_ratio = self.negative_ratio
pos_indexes = torch.arange(0, num_pos_node_pairs)
pos_indexes = torch.arange(
0,
num_pos_node_pairs,
device=seeds.device,
)
neg_indexes = pos_indexes.repeat_interleave(negative_ratio)
indexes = torch.cat((pos_indexes, neg_indexes))
# Construct labels for all node pairs.
pos_num = node_pairs.shape[0]
neg_num = seeds.shape[0] - pos_num
labels = torch.cat(
(torch.ones(pos_num), torch.zeros(neg_num))
).bool()
(
torch.ones(
pos_num,
dtype=torch.bool,
device=seeds.device,
),
torch.zeros(
neg_num,
dtype=torch.bool,
device=seeds.device,
),
),
)
return seeds, labels, indexes
else:
return self.graph.sample_negative_edges_uniform(
......
import re
import backend as F
import dgl.graphbolt as gb
import pytest
import torch
......@@ -14,7 +16,9 @@ def test_NegativeSampler_invoke():
torch.arange(0, 2 * num_seeds).reshape(-1, 2), names="node_pairs"
)
batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(
F.ctx()
)
negative_ratio = 2
# Invoke NegativeSampler via class constructor.
......@@ -35,13 +39,17 @@ def test_NegativeSampler_invoke():
def test_UniformNegativeSampler_invoke():
# Instantiate graph and required datapipes.
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True)
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True).to(
F.ctx()
)
num_seeds = 30
item_set = gb.ItemSet(
torch.arange(0, 2 * num_seeds).reshape(-1, 2), names="seeds"
)
batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(
F.ctx()
)
negative_ratio = 2
def _verify(negative_sampler):
......@@ -70,13 +78,17 @@ def test_UniformNegativeSampler_invoke():
def test_UniformNegativeSampler_node_pairs_invoke():
# Instantiate graph and required datapipes.
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True)
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True).to(
F.ctx()
)
num_seeds = 30
item_set = gb.ItemSet(
torch.arange(0, 2 * num_seeds).reshape(-1, 2), names="node_pairs"
)
batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(
F.ctx()
)
negative_ratio = 2
# Verify iteration over UniformNegativeSampler.
......@@ -106,13 +118,17 @@ def test_UniformNegativeSampler_node_pairs_invoke():
@pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20])
def test_Uniform_NegativeSampler_node_pairs(negative_ratio):
# Construct FusedCSCSamplingGraph.
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True)
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True).to(
F.ctx()
)
num_seeds = 30
item_set = gb.ItemSet(
torch.arange(0, num_seeds * 2).reshape(-1, 2), names="node_pairs"
)
batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(
F.ctx()
)
# Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler(
item_sampler,
......@@ -134,13 +150,17 @@ def test_Uniform_NegativeSampler_node_pairs(negative_ratio):
@pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20])
def test_Uniform_NegativeSampler(negative_ratio):
# Construct FusedCSCSamplingGraph.
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True)
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True).to(
F.ctx()
)
num_seeds = 30
item_set = gb.ItemSet(
torch.arange(0, num_seeds * 2).reshape(-1, 2), names="seeds"
)
batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(
F.ctx()
)
# Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler(
item_sampler,
......@@ -159,12 +179,15 @@ def test_Uniform_NegativeSampler(negative_ratio):
neg_src = data.seeds[batch_size:, 0]
assert torch.equal(pos_src.repeat_interleave(negative_ratio), neg_src)
# Check labels.
assert torch.equal(data.labels[:batch_size], torch.ones(batch_size))
assert torch.equal(
data.labels[batch_size:], torch.zeros(batch_size * negative_ratio)
data.labels[:batch_size], torch.ones(batch_size).to(F.ctx())
)
assert torch.equal(
data.labels[batch_size:],
torch.zeros(batch_size * negative_ratio).to(F.ctx()),
)
# Check indexes.
pos_indexes = torch.arange(0, batch_size)
pos_indexes = torch.arange(0, batch_size).to(F.ctx())
neg_indexes = pos_indexes.repeat_interleave(negative_ratio)
expected_indexes = torch.cat((pos_indexes, neg_indexes))
assert torch.equal(data.indexes, expected_indexes)
......@@ -173,13 +196,17 @@ def test_Uniform_NegativeSampler(negative_ratio):
def test_Uniform_NegativeSampler_error_shape():
# 1. seeds with shape N*3.
# Construct FusedCSCSamplingGraph.
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True)
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True).to(
F.ctx()
)
num_seeds = 30
item_set = gb.ItemSet(
torch.arange(0, num_seeds * 3).reshape(-1, 3), names="seeds"
)
batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(
F.ctx()
)
negative_ratio = 2
# Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler(
......@@ -201,7 +228,9 @@ def test_Uniform_NegativeSampler_error_shape():
item_set = gb.ItemSet(
torch.arange(0, num_seeds * 2).reshape(-1, 2, 1), names="seeds"
)
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(
F.ctx()
)
# Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler(
item_sampler,
......@@ -220,7 +249,9 @@ def test_Uniform_NegativeSampler_error_shape():
# 3. seeds with shape N.
# Construct FusedCSCSamplingGraph.
item_set = gb.ItemSet(torch.arange(0, num_seeds), names="seeds")
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(
F.ctx()
)
# Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler(
item_sampler,
......@@ -260,7 +291,7 @@ def get_hetero_graph():
def test_NegativeSampler_Hetero_node_pairs_Data():
graph = get_hetero_graph()
graph = get_hetero_graph().to(F.ctx())
itemset = gb.ItemSetDict(
{
"n1:e1:n2": gb.ItemSet(
......@@ -274,13 +305,13 @@ def test_NegativeSampler_Hetero_node_pairs_Data():
}
)
item_sampler = gb.ItemSampler(itemset, batch_size=2)
item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
negative_dp = gb.UniformNegativeSampler(item_sampler, graph, 1)
assert len(list(negative_dp)) == 5
def test_NegativeSampler_Hetero_Data():
graph = get_hetero_graph()
graph = get_hetero_graph().to(F.ctx())
itemset = gb.ItemSetDict(
{
"n1:e1:n2": gb.ItemSet(
......@@ -295,7 +326,9 @@ def test_NegativeSampler_Hetero_Data():
)
batch_size = 2
negative_ratio = 1
item_sampler = gb.ItemSampler(itemset, batch_size=batch_size)
item_sampler = gb.ItemSampler(itemset, batch_size=batch_size).copy_to(
F.ctx()
)
negative_dp = gb.UniformNegativeSampler(item_sampler, graph, negative_ratio)
assert len(list(negative_dp)) == 5
# Perform negative sampling.
......@@ -311,5 +344,5 @@ def test_NegativeSampler_Hetero_Data():
for etype, seeds_data in data.seeds.items():
neg_src = seeds_data[batch_size:, 0]
neg_dst = seeds_data[batch_size:, 1]
assert torch.equal(expected_neg_src[i][etype], neg_src)
assert torch.equal(expected_neg_src[i][etype].to(F.ctx()), neg_src)
assert (neg_dst < 3).all(), neg_dst
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment