Unverified commit af0b63ed, authored by yxy235 and committed by GitHub
Browse files

[GraphBolt] Fix gpu `NegativeSampler` for seeds. (#7068)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 3d854a6b
...@@ -1026,7 +1026,13 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -1026,7 +1026,13 @@ class FusedCSCSamplingGraph(SamplingGraph):
torch.cat( torch.cat(
( (
pos_src.repeat_interleave(negative_ratio), pos_src.repeat_interleave(negative_ratio),
torch.randint(0, max_node_id, (num_negative,)), torch.randint(
0,
max_node_id,
(num_negative,),
dtype=node_pairs.dtype,
device=node_pairs.device,
),
), ),
) )
.view(2, num_negative) .view(2, num_negative)
......
...@@ -76,15 +76,30 @@ class UniformNegativeSampler(NegativeSampler): ...@@ -76,15 +76,30 @@ class UniformNegativeSampler(NegativeSampler):
# Construct indexes for all node pairs. # Construct indexes for all node pairs.
num_pos_node_pairs = node_pairs.shape[0] num_pos_node_pairs = node_pairs.shape[0]
negative_ratio = self.negative_ratio negative_ratio = self.negative_ratio
pos_indexes = torch.arange(0, num_pos_node_pairs) pos_indexes = torch.arange(
0,
num_pos_node_pairs,
device=seeds.device,
)
neg_indexes = pos_indexes.repeat_interleave(negative_ratio) neg_indexes = pos_indexes.repeat_interleave(negative_ratio)
indexes = torch.cat((pos_indexes, neg_indexes)) indexes = torch.cat((pos_indexes, neg_indexes))
# Construct labels for all node pairs. # Construct labels for all node pairs.
pos_num = node_pairs.shape[0] pos_num = node_pairs.shape[0]
neg_num = seeds.shape[0] - pos_num neg_num = seeds.shape[0] - pos_num
labels = torch.cat( labels = torch.cat(
(torch.ones(pos_num), torch.zeros(neg_num)) (
).bool() torch.ones(
pos_num,
dtype=torch.bool,
device=seeds.device,
),
torch.zeros(
neg_num,
dtype=torch.bool,
device=seeds.device,
),
),
)
return seeds, labels, indexes return seeds, labels, indexes
else: else:
return self.graph.sample_negative_edges_uniform( return self.graph.sample_negative_edges_uniform(
......
import re import re
import backend as F
import dgl.graphbolt as gb import dgl.graphbolt as gb
import pytest import pytest
import torch import torch
...@@ -14,7 +16,9 @@ def test_NegativeSampler_invoke(): ...@@ -14,7 +16,9 @@ def test_NegativeSampler_invoke():
torch.arange(0, 2 * num_seeds).reshape(-1, 2), names="node_pairs" torch.arange(0, 2 * num_seeds).reshape(-1, 2), names="node_pairs"
) )
batch_size = 10 batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size) item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(
F.ctx()
)
negative_ratio = 2 negative_ratio = 2
# Invoke NegativeSampler via class constructor. # Invoke NegativeSampler via class constructor.
...@@ -35,13 +39,17 @@ def test_NegativeSampler_invoke(): ...@@ -35,13 +39,17 @@ def test_NegativeSampler_invoke():
def test_UniformNegativeSampler_invoke(): def test_UniformNegativeSampler_invoke():
# Instantiate graph and required datapipes. # Instantiate graph and required datapipes.
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True) graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True).to(
F.ctx()
)
num_seeds = 30 num_seeds = 30
item_set = gb.ItemSet( item_set = gb.ItemSet(
torch.arange(0, 2 * num_seeds).reshape(-1, 2), names="seeds" torch.arange(0, 2 * num_seeds).reshape(-1, 2), names="seeds"
) )
batch_size = 10 batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size) item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(
F.ctx()
)
negative_ratio = 2 negative_ratio = 2
def _verify(negative_sampler): def _verify(negative_sampler):
...@@ -70,13 +78,17 @@ def test_UniformNegativeSampler_invoke(): ...@@ -70,13 +78,17 @@ def test_UniformNegativeSampler_invoke():
def test_UniformNegativeSampler_node_pairs_invoke(): def test_UniformNegativeSampler_node_pairs_invoke():
# Instantiate graph and required datapipes. # Instantiate graph and required datapipes.
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True) graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True).to(
F.ctx()
)
num_seeds = 30 num_seeds = 30
item_set = gb.ItemSet( item_set = gb.ItemSet(
torch.arange(0, 2 * num_seeds).reshape(-1, 2), names="node_pairs" torch.arange(0, 2 * num_seeds).reshape(-1, 2), names="node_pairs"
) )
batch_size = 10 batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size) item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(
F.ctx()
)
negative_ratio = 2 negative_ratio = 2
# Verify iteration over UniformNegativeSampler. # Verify iteration over UniformNegativeSampler.
...@@ -106,13 +118,17 @@ def test_UniformNegativeSampler_node_pairs_invoke(): ...@@ -106,13 +118,17 @@ def test_UniformNegativeSampler_node_pairs_invoke():
@pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20]) @pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20])
def test_Uniform_NegativeSampler_node_pairs(negative_ratio): def test_Uniform_NegativeSampler_node_pairs(negative_ratio):
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True) graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True).to(
F.ctx()
)
num_seeds = 30 num_seeds = 30
item_set = gb.ItemSet( item_set = gb.ItemSet(
torch.arange(0, num_seeds * 2).reshape(-1, 2), names="node_pairs" torch.arange(0, num_seeds * 2).reshape(-1, 2), names="node_pairs"
) )
batch_size = 10 batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size) item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(
F.ctx()
)
# Construct NegativeSampler. # Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler( negative_sampler = gb.UniformNegativeSampler(
item_sampler, item_sampler,
...@@ -134,13 +150,17 @@ def test_Uniform_NegativeSampler_node_pairs(negative_ratio): ...@@ -134,13 +150,17 @@ def test_Uniform_NegativeSampler_node_pairs(negative_ratio):
@pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20]) @pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20])
def test_Uniform_NegativeSampler(negative_ratio): def test_Uniform_NegativeSampler(negative_ratio):
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True) graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True).to(
F.ctx()
)
num_seeds = 30 num_seeds = 30
item_set = gb.ItemSet( item_set = gb.ItemSet(
torch.arange(0, num_seeds * 2).reshape(-1, 2), names="seeds" torch.arange(0, num_seeds * 2).reshape(-1, 2), names="seeds"
) )
batch_size = 10 batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size) item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(
F.ctx()
)
# Construct NegativeSampler. # Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler( negative_sampler = gb.UniformNegativeSampler(
item_sampler, item_sampler,
...@@ -159,12 +179,15 @@ def test_Uniform_NegativeSampler(negative_ratio): ...@@ -159,12 +179,15 @@ def test_Uniform_NegativeSampler(negative_ratio):
neg_src = data.seeds[batch_size:, 0] neg_src = data.seeds[batch_size:, 0]
assert torch.equal(pos_src.repeat_interleave(negative_ratio), neg_src) assert torch.equal(pos_src.repeat_interleave(negative_ratio), neg_src)
# Check labels. # Check labels.
assert torch.equal(data.labels[:batch_size], torch.ones(batch_size))
assert torch.equal( assert torch.equal(
data.labels[batch_size:], torch.zeros(batch_size * negative_ratio) data.labels[:batch_size], torch.ones(batch_size).to(F.ctx())
)
assert torch.equal(
data.labels[batch_size:],
torch.zeros(batch_size * negative_ratio).to(F.ctx()),
) )
# Check indexes. # Check indexes.
pos_indexes = torch.arange(0, batch_size) pos_indexes = torch.arange(0, batch_size).to(F.ctx())
neg_indexes = pos_indexes.repeat_interleave(negative_ratio) neg_indexes = pos_indexes.repeat_interleave(negative_ratio)
expected_indexes = torch.cat((pos_indexes, neg_indexes)) expected_indexes = torch.cat((pos_indexes, neg_indexes))
assert torch.equal(data.indexes, expected_indexes) assert torch.equal(data.indexes, expected_indexes)
...@@ -173,13 +196,17 @@ def test_Uniform_NegativeSampler(negative_ratio): ...@@ -173,13 +196,17 @@ def test_Uniform_NegativeSampler(negative_ratio):
def test_Uniform_NegativeSampler_error_shape(): def test_Uniform_NegativeSampler_error_shape():
# 1. seeds with shape N*3. # 1. seeds with shape N*3.
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True) graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True).to(
F.ctx()
)
num_seeds = 30 num_seeds = 30
item_set = gb.ItemSet( item_set = gb.ItemSet(
torch.arange(0, num_seeds * 3).reshape(-1, 3), names="seeds" torch.arange(0, num_seeds * 3).reshape(-1, 3), names="seeds"
) )
batch_size = 10 batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size) item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(
F.ctx()
)
negative_ratio = 2 negative_ratio = 2
# Construct NegativeSampler. # Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler( negative_sampler = gb.UniformNegativeSampler(
...@@ -201,7 +228,9 @@ def test_Uniform_NegativeSampler_error_shape(): ...@@ -201,7 +228,9 @@ def test_Uniform_NegativeSampler_error_shape():
item_set = gb.ItemSet( item_set = gb.ItemSet(
torch.arange(0, num_seeds * 2).reshape(-1, 2, 1), names="seeds" torch.arange(0, num_seeds * 2).reshape(-1, 2, 1), names="seeds"
) )
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size) item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(
F.ctx()
)
# Construct NegativeSampler. # Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler( negative_sampler = gb.UniformNegativeSampler(
item_sampler, item_sampler,
...@@ -220,7 +249,9 @@ def test_Uniform_NegativeSampler_error_shape(): ...@@ -220,7 +249,9 @@ def test_Uniform_NegativeSampler_error_shape():
# 3. seeds with shape N. # 3. seeds with shape N.
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
item_set = gb.ItemSet(torch.arange(0, num_seeds), names="seeds") item_set = gb.ItemSet(torch.arange(0, num_seeds), names="seeds")
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size) item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(
F.ctx()
)
# Construct NegativeSampler. # Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler( negative_sampler = gb.UniformNegativeSampler(
item_sampler, item_sampler,
...@@ -260,7 +291,7 @@ def get_hetero_graph(): ...@@ -260,7 +291,7 @@ def get_hetero_graph():
def test_NegativeSampler_Hetero_node_pairs_Data(): def test_NegativeSampler_Hetero_node_pairs_Data():
graph = get_hetero_graph() graph = get_hetero_graph().to(F.ctx())
itemset = gb.ItemSetDict( itemset = gb.ItemSetDict(
{ {
"n1:e1:n2": gb.ItemSet( "n1:e1:n2": gb.ItemSet(
...@@ -274,13 +305,13 @@ def test_NegativeSampler_Hetero_node_pairs_Data(): ...@@ -274,13 +305,13 @@ def test_NegativeSampler_Hetero_node_pairs_Data():
} }
) )
item_sampler = gb.ItemSampler(itemset, batch_size=2) item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
negative_dp = gb.UniformNegativeSampler(item_sampler, graph, 1) negative_dp = gb.UniformNegativeSampler(item_sampler, graph, 1)
assert len(list(negative_dp)) == 5 assert len(list(negative_dp)) == 5
def test_NegativeSampler_Hetero_Data(): def test_NegativeSampler_Hetero_Data():
graph = get_hetero_graph() graph = get_hetero_graph().to(F.ctx())
itemset = gb.ItemSetDict( itemset = gb.ItemSetDict(
{ {
"n1:e1:n2": gb.ItemSet( "n1:e1:n2": gb.ItemSet(
...@@ -295,7 +326,9 @@ def test_NegativeSampler_Hetero_Data(): ...@@ -295,7 +326,9 @@ def test_NegativeSampler_Hetero_Data():
) )
batch_size = 2 batch_size = 2
negative_ratio = 1 negative_ratio = 1
item_sampler = gb.ItemSampler(itemset, batch_size=batch_size) item_sampler = gb.ItemSampler(itemset, batch_size=batch_size).copy_to(
F.ctx()
)
negative_dp = gb.UniformNegativeSampler(item_sampler, graph, negative_ratio) negative_dp = gb.UniformNegativeSampler(item_sampler, graph, negative_ratio)
assert len(list(negative_dp)) == 5 assert len(list(negative_dp)) == 5
# Perform negative sampling. # Perform negative sampling.
...@@ -311,5 +344,5 @@ def test_NegativeSampler_Hetero_Data(): ...@@ -311,5 +344,5 @@ def test_NegativeSampler_Hetero_Data():
for etype, seeds_data in data.seeds.items(): for etype, seeds_data in data.seeds.items():
neg_src = seeds_data[batch_size:, 0] neg_src = seeds_data[batch_size:, 0]
neg_dst = seeds_data[batch_size:, 1] neg_dst = seeds_data[batch_size:, 1]
assert torch.equal(expected_neg_src[i][etype], neg_src) assert torch.equal(expected_neg_src[i][etype].to(F.ctx()), neg_src)
assert (neg_dst < 3).all(), neg_dst assert (neg_dst < 3).all(), neg_dst
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment