Unverified commit b743cdef, authored by yxy235, committed by GitHub
Browse files

[GraphBolt] Hack for original `seed_nodes` and `node_pairs`. (#7248)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 3df6e301
...@@ -62,6 +62,37 @@ def minibatcher_default(batch, names): ...@@ -62,6 +62,37 @@ def minibatcher_default(batch, names):
else: else:
init_data = {name: item for item, name in zip(batch, names)} init_data = {name: item for item, name in zip(batch, names)}
minibatch = MiniBatch() minibatch = MiniBatch()
# TODO(#7254): Hacks for original `seed_nodes` and `node_pairs`, which need
# to be cleaned up later.
if "node_pairs" in names:
pos_seeds = init_data["node_pairs"]
# Build negative graph.
if "negative_srcs" in names and "negative_dsts" in names:
neg_srcs = init_data["negative_srcs"]
neg_dsts = init_data["negative_dsts"]
(
init_data["seeds"],
init_data["labels"],
init_data["indexes"],
) = _construct_seeds(
pos_seeds, neg_srcs=neg_srcs, neg_dsts=neg_dsts
)
elif "negative_srcs" in names:
neg_srcs = init_data["negative_srcs"]
(
init_data["seeds"],
init_data["labels"],
init_data["indexes"],
) = _construct_seeds(pos_seeds, neg_srcs=neg_srcs)
elif "negative_dsts" in names:
neg_dsts = init_data["negative_dsts"]
(
init_data["seeds"],
init_data["labels"],
init_data["indexes"],
) = _construct_seeds(pos_seeds, neg_dsts=neg_dsts)
else:
init_data["seeds"] = pos_seeds
for name, item in init_data.items(): for name, item in init_data.items():
if not hasattr(minibatch, name): if not hasattr(minibatch, name):
dgl_warning( dgl_warning(
...@@ -69,13 +100,12 @@ def minibatcher_default(batch, names): ...@@ -69,13 +100,12 @@ def minibatcher_default(batch, names):
"`MiniBatch`. You probably need to provide a customized " "`MiniBatch`. You probably need to provide a customized "
"`MiniBatcher`." "`MiniBatcher`."
) )
if name == "node_pairs": # TODO(#7254): Hacks for original `seed_nodes` and `node_pairs`, which
# `node_pairs` is passed as a tensor in shape of `(N, 2)` and # need to be cleaned up later.
# should be converted to a tuple of `(src, dst)`. if name == "seed_nodes":
if isinstance(item, Mapping): name = "seeds"
item = {key: (item[key][:, 0], item[key][:, 1]) for key in item} if name in ("node_pairs", "negative_srcs", "negative_dsts"):
else: continue
item = (item[:, 0], item[:, 1])
setattr(minibatch, name, item) setattr(minibatch, name, item)
return minibatch return minibatch
...@@ -744,3 +774,80 @@ class DistributedItemSampler(ItemSampler): ...@@ -744,3 +774,80 @@ class DistributedItemSampler(ItemSampler):
) )
self._world_size = dist.get_world_size() self._world_size = dist.get_world_size()
self._rank = dist.get_rank() self._rank = dist.get_rank()
def _construct_seeds(pos_seeds, neg_srcs=None, neg_dsts=None):
# For homogeneous graph.
if isinstance(pos_seeds, torch.Tensor):
negative_ratio = neg_srcs.size(1) if neg_srcs else neg_dsts.size(1)
neg_srcs = (
neg_srcs
if neg_srcs is not None
else pos_seeds[:, 0].repeat_interleave(negative_ratio)
).view(-1)
neg_dsts = (
neg_dsts
if neg_dsts is not None
else pos_seeds[:, 1].repeat_interleave(negative_ratio)
).view(-1)
neg_seeds = torch.cat((neg_srcs, neg_dsts)).view(2, -1).T
seeds = torch.cat((pos_seeds, neg_seeds))
pos_seeds_num = pos_seeds.size(0)
labels = torch.empty(seeds.size(0), device=pos_seeds.device)
labels[:pos_seeds_num] = 1
labels[pos_seeds_num:] = 0
pos_indexes = torch.arange(
0,
pos_seeds_num,
device=pos_seeds.device,
)
neg_indexes = pos_indexes.repeat_interleave(negative_ratio)
indexes = torch.cat((pos_indexes, neg_indexes))
# For heterogeneous graph.
else:
negative_ratio = (
list(neg_srcs.values())[0].size(1)
if neg_srcs
else list(neg_dsts.values())[0].size(1)
)
seeds = {}
labels = {}
indexes = {}
for etype in pos_seeds:
neg_src = (
neg_srcs[etype]
if neg_srcs is not None
else pos_seeds[etype][:, 0].repeat_interleave(negative_ratio)
).view(-1)
neg_dst = (
neg_dsts[etype]
if neg_dsts is not None
else pos_seeds[etype][:, 1].repeat_interleave(negative_ratio)
).view(-1)
seeds[etype] = torch.cat(
(
pos_seeds[etype],
torch.cat(
(
neg_src,
neg_dst,
)
)
.view(2, -1)
.T,
)
)
pos_seeds_num = pos_seeds[etype].size(0)
labels[etype] = torch.empty(
seeds[etype].size(0), device=pos_seeds[etype].device
)
labels[etype][:pos_seeds_num] = 1
labels[etype][pos_seeds_num:] = 0
pos_indexes = torch.arange(
0,
pos_seeds_num,
device=pos_seeds[etype].device,
)
neg_indexes = pos_indexes.repeat_interleave(negative_ratio)
indexes[etype] = torch.cat((pos_indexes, neg_indexes))
return seeds, labels, indexes
...@@ -99,14 +99,14 @@ def test_InSubgraphSampler_homo(): ...@@ -99,14 +99,14 @@ def test_InSubgraphSampler_homo():
return _indices return _indices
mn = next(it) mn = next(it)
assert torch.equal(mn.seed_nodes, torch.LongTensor([0]).to(F.ctx())) assert torch.equal(mn.seeds, torch.LongTensor([0]).to(F.ctx()))
assert torch.equal( assert torch.equal(
mn.sampled_subgraphs[0].sampled_csc.indptr, mn.sampled_subgraphs[0].sampled_csc.indptr,
torch.tensor([0, 3]).to(F.ctx()), torch.tensor([0, 3]).to(F.ctx()),
) )
mn = next(it) mn = next(it)
assert torch.equal(mn.seed_nodes, torch.LongTensor([5]).to(F.ctx())) assert torch.equal(mn.seeds, torch.LongTensor([5]).to(F.ctx()))
assert torch.equal( assert torch.equal(
mn.sampled_subgraphs[0].sampled_csc.indptr, mn.sampled_subgraphs[0].sampled_csc.indptr,
torch.tensor([0, 2]).to(F.ctx()), torch.tensor([0, 2]).to(F.ctx()),
...@@ -114,7 +114,7 @@ def test_InSubgraphSampler_homo(): ...@@ -114,7 +114,7 @@ def test_InSubgraphSampler_homo():
assert torch.equal(original_indices(mn), torch.tensor([1, 4]).to(F.ctx())) assert torch.equal(original_indices(mn), torch.tensor([1, 4]).to(F.ctx()))
mn = next(it) mn = next(it)
assert torch.equal(mn.seed_nodes, torch.LongTensor([3]).to(F.ctx())) assert torch.equal(mn.seeds, torch.LongTensor([3]).to(F.ctx()))
assert torch.equal( assert torch.equal(
mn.sampled_subgraphs[0].sampled_csc.indptr, mn.sampled_subgraphs[0].sampled_csc.indptr,
torch.tensor([0, 2]).to(F.ctx()), torch.tensor([0, 2]).to(F.ctx()),
...@@ -176,9 +176,7 @@ def test_InSubgraphSampler_hetero(): ...@@ -176,9 +176,7 @@ def test_InSubgraphSampler_hetero():
it = iter(in_subgraph_sampler) it = iter(in_subgraph_sampler)
mn = next(it) mn = next(it)
assert torch.equal( assert torch.equal(mn.seeds["N0"], torch.LongTensor([1, 0]).to(F.ctx()))
mn.seed_nodes["N0"], torch.LongTensor([1, 0]).to(F.ctx())
)
expected_sampled_csc = { expected_sampled_csc = {
"N0:R0:N0": gb.CSCFormatBase( "N0:R0:N0": gb.CSCFormatBase(
indptr=torch.LongTensor([0, 1, 3]), indptr=torch.LongTensor([0, 1, 3]),
...@@ -203,7 +201,7 @@ def test_InSubgraphSampler_hetero(): ...@@ -203,7 +201,7 @@ def test_InSubgraphSampler_hetero():
) )
mn = next(it) mn = next(it)
assert mn.seed_nodes == { assert mn.seeds == {
"N0": torch.LongTensor([2]).to(F.ctx()), "N0": torch.LongTensor([2]).to(F.ctx()),
"N1": torch.LongTensor([0]).to(F.ctx()), "N1": torch.LongTensor([0]).to(F.ctx()),
} }
...@@ -230,9 +228,7 @@ def test_InSubgraphSampler_hetero(): ...@@ -230,9 +228,7 @@ def test_InSubgraphSampler_hetero():
) )
mn = next(it) mn = next(it)
assert torch.equal( assert torch.equal(mn.seeds["N1"], torch.LongTensor([2, 1]).to(F.ctx()))
mn.seed_nodes["N1"], torch.LongTensor([2, 1]).to(F.ctx())
)
expected_sampled_csc = { expected_sampled_csc = {
"N0:R0:N0": gb.CSCFormatBase( "N0:R0:N0": gb.CSCFormatBase(
indptr=torch.LongTensor([0]), indices=torch.LongTensor([]) indptr=torch.LongTensor([0]), indices=torch.LongTensor([])
......
...@@ -95,9 +95,10 @@ def test_UniformNegativeSampler_node_pairs_invoke(): ...@@ -95,9 +95,10 @@ def test_UniformNegativeSampler_node_pairs_invoke():
def _verify(negative_sampler): def _verify(negative_sampler):
for data in negative_sampler: for data in negative_sampler:
# Assertation # Assertation
assert data.negative_srcs is None seeds_len = batch_size + batch_size * negative_ratio
assert data.negative_dsts.size(0) == batch_size assert data.seeds.size(0) == seeds_len
assert data.negative_dsts.size(1) == negative_ratio assert data.labels.size(0) == seeds_len
assert data.indexes.size(0) == seeds_len
# Invoke UniformNegativeSampler via class constructor. # Invoke UniformNegativeSampler via class constructor.
negative_sampler = gb.UniformNegativeSampler( negative_sampler = gb.UniformNegativeSampler(
...@@ -137,14 +138,30 @@ def test_Uniform_NegativeSampler_node_pairs(negative_ratio): ...@@ -137,14 +138,30 @@ def test_Uniform_NegativeSampler_node_pairs(negative_ratio):
) )
# Perform Negative sampling. # Perform Negative sampling.
for data in negative_sampler: for data in negative_sampler:
pos_src, pos_dst = data.node_pairs expected_labels = torch.empty(
neg_src, neg_dst = data.negative_srcs, data.negative_dsts batch_size * (negative_ratio + 1), device=F.ctx()
)
expected_labels[:batch_size] = 1
expected_labels[batch_size:] = 0
expected_indexes = torch.arange(batch_size, device=F.ctx())
expected_indexes = torch.cat(
(
expected_indexes,
expected_indexes.repeat_interleave(negative_ratio),
)
)
expected_neg_src = data.seeds[:batch_size][:, 0].repeat_interleave(
negative_ratio
)
# Assertation # Assertation
assert len(pos_src) == batch_size assert data.negative_srcs is None
assert len(pos_dst) == batch_size assert data.negative_dsts is None
assert len(neg_dst) == batch_size assert data.labels is not None
assert neg_src is None assert data.indexes is not None
assert neg_dst.numel() == batch_size * negative_ratio assert data.seeds.size(0) == batch_size * (negative_ratio + 1)
assert torch.equal(data.labels, expected_labels)
assert torch.equal(data.indexes, expected_indexes)
assert torch.equal(data.seeds[batch_size:][:, 0], expected_neg_src)
@pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20]) @pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20])
......
...@@ -15,18 +15,18 @@ from . import gb_test_utils ...@@ -15,18 +15,18 @@ from . import gb_test_utils
@unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test") @unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
def test_CopyTo(): def test_CopyTo():
item_sampler = gb.ItemSampler( item_sampler = gb.ItemSampler(
gb.ItemSet(torch.arange(20), names="seed_nodes"), 4 gb.ItemSet(torch.arange(20), names="seeds"), 4
) )
# Invoke CopyTo via class constructor. # Invoke CopyTo via class constructor.
dp = gb.CopyTo(item_sampler, "cuda") dp = gb.CopyTo(item_sampler, "cuda")
for data in dp: for data in dp:
assert data.seed_nodes.device.type == "cuda" assert data.seeds.device.type == "cuda"
# Invoke CopyTo via functional form. # Invoke CopyTo via functional form.
dp = item_sampler.copy_to("cuda") dp = item_sampler.copy_to("cuda")
for data in dp: for data in dp:
assert data.seed_nodes.device.type == "cuda" assert data.seeds.device.type == "cuda"
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -37,7 +37,6 @@ def test_CopyTo(): ...@@ -37,7 +37,6 @@ def test_CopyTo():
"link_prediction", "link_prediction",
"edge_classification", "edge_classification",
"extra_attrs", "extra_attrs",
"other",
], ],
) )
@unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test") @unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
...@@ -63,11 +62,6 @@ def test_CopyToWithMiniBatches_original(task): ...@@ -63,11 +62,6 @@ def test_CopyToWithMiniBatches_original(task):
(torch.arange(2 * N).reshape(-1, 2), torch.arange(N)), (torch.arange(2 * N).reshape(-1, 2), torch.arange(N)),
names=("node_pairs", "labels"), names=("node_pairs", "labels"),
) )
else:
itemset = gb.ItemSet(
(torch.arange(2 * N).reshape(-1, 2), torch.arange(N)),
names=("node_pairs", "seed_nodes"),
)
graph = gb_test_utils.rand_csc_graph(100, 0.15, bidirection_edge=True) graph = gb_test_utils.rand_csc_graph(100, 0.15, bidirection_edge=True)
features = {} features = {}
...@@ -96,38 +90,25 @@ def test_CopyToWithMiniBatches_original(task): ...@@ -96,38 +90,25 @@ def test_CopyToWithMiniBatches_original(task):
"sampled_subgraphs", "sampled_subgraphs",
"labels", "labels",
"blocks", "blocks",
"seeds",
] ]
elif task == "node_inference": elif task == "node_inference":
copied_attrs = [ copied_attrs = [
"seed_nodes", "seeds",
"sampled_subgraphs", "sampled_subgraphs",
"blocks", "blocks",
"labels", "labels",
] ]
elif task == "link_prediction": elif task == "link_prediction" or task == "edge_classification":
copied_attrs = [ copied_attrs = [
"compacted_node_pairs", "labels",
"node_features", "compacted_seeds",
"edge_features",
"sampled_subgraphs", "sampled_subgraphs",
"compacted_negative_srcs", "indexes",
"compacted_negative_dsts",
"blocks",
"positive_node_pairs",
"negative_node_pairs",
"node_pairs_with_labels",
]
elif task == "edge_classification":
copied_attrs = [
"compacted_node_pairs",
"node_features", "node_features",
"edge_features", "edge_features",
"sampled_subgraphs",
"labels",
"blocks", "blocks",
"positive_node_pairs", "seeds",
"negative_node_pairs",
"node_pairs_with_labels",
] ]
elif task == "extra_attrs": elif task == "extra_attrs":
copied_attrs = [ copied_attrs = [
...@@ -137,6 +118,7 @@ def test_CopyToWithMiniBatches_original(task): ...@@ -137,6 +118,7 @@ def test_CopyToWithMiniBatches_original(task):
"labels", "labels",
"blocks", "blocks",
"seed_nodes", "seed_nodes",
"seeds",
] ]
def test_data_device(datapipe): def test_data_device(datapipe):
......
...@@ -78,7 +78,7 @@ def test_FeatureFetcher_with_edges_homo(): ...@@ -78,7 +78,7 @@ def test_FeatureFetcher_with_edges_homo():
) )
def add_node_and_edge_ids(minibatch): def add_node_and_edge_ids(minibatch):
seeds = minibatch.seed_nodes seeds = minibatch.seeds
subgraphs = [] subgraphs = []
for _ in range(3): for _ in range(3):
sampled_csc = gb.CSCFormatBase( sampled_csc = gb.CSCFormatBase(
...@@ -172,7 +172,7 @@ def test_FeatureFetcher_with_edges_hetero(): ...@@ -172,7 +172,7 @@ def test_FeatureFetcher_with_edges_hetero():
b = torch.tensor([[random.randint(0, 10)] for _ in range(50)]) b = torch.tensor([[random.randint(0, 10)] for _ in range(50)])
def add_node_and_edge_ids(minibatch): def add_node_and_edge_ids(minibatch):
seeds = minibatch.seed_nodes seeds = minibatch.seeds
subgraphs = [] subgraphs = []
original_edge_ids = { original_edge_ids = {
"n1:e1:n2": torch.randint(0, 50, (10,)), "n1:e1:n2": torch.randint(0, 50, (10,)),
......
...@@ -60,69 +60,85 @@ def test_integration_link_prediction(): ...@@ -60,69 +60,85 @@ def test_integration_link_prediction():
) )
expected = [ expected = [
str( str(
"""MiniBatch(seeds=None, """MiniBatch(seeds=tensor([[5, 1],
[3, 2],
[3, 2],
[3, 3],
[5, 0],
[5, 0],
[3, 3],
[3, 0],
[3, 5],
[3, 3],
[3, 3],
[3, 4]]),
seed_nodes=None, seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2], dtype=torch.int32), sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 2, 2, 2, 3], dtype=torch.int32),
indices=tensor([0, 4], dtype=torch.int32), indices=tensor([0, 5, 4], dtype=torch.int32),
), ),
original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]), original_row_node_ids=tensor([5, 1, 3, 2, 0, 4]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2, 0, 4]), original_column_node_ids=tensor([5, 1, 3, 2, 0, 4]),
), ),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2], dtype=torch.int32), SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2], dtype=torch.int32),
indices=tensor([5, 4], dtype=torch.int32), indices=tensor([5, 4], dtype=torch.int32),
), ),
original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]), original_row_node_ids=tensor([5, 1, 3, 2, 0, 4]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2, 0, 4]), original_column_node_ids=tensor([5, 1, 3, 2, 0, 4]),
)], )],
positive_node_pairs=(tensor([0, 1, 1, 1]), positive_node_pairs=None,
tensor([2, 3, 3, 1])), node_pairs_with_labels=None,
node_pairs_with_labels=((tensor([0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1]), tensor([2, 3, 3, 1, 4, 4, 1, 4, 0, 1, 1, 5])), node_pairs=None,
tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.])),
node_pairs=(tensor([5, 3, 3, 3]),
tensor([1, 2, 2, 3])),
node_features={'feat': tensor([[0.5160, 0.2486], node_features={'feat': tensor([[0.5160, 0.2486],
[0.8672, 0.2276],
[0.6172, 0.7865], [0.6172, 0.7865],
[0.8672, 0.2276],
[0.2109, 0.1089], [0.2109, 0.1089],
[0.9634, 0.2294], [0.9634, 0.2294],
[0.5503, 0.8223]])}, [0.5503, 0.8223]])},
negative_srcs=None, negative_srcs=None,
negative_node_pairs=(tensor([[0, 0], negative_node_pairs=None,
[1, 1], negative_dsts=None,
[1, 1], labels=tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.]),
[1, 1]]), input_nodes=tensor([5, 1, 3, 2, 0, 4]),
tensor([[4, 4], indexes=tensor([0, 1, 2, 3, 0, 0, 1, 1, 2, 2, 3, 3]),
[1, 4],
[0, 1],
[1, 5]])),
negative_dsts=tensor([[0, 0],
[3, 0],
[5, 3],
[3, 4]]),
labels=None,
input_nodes=tensor([5, 3, 1, 2, 0, 4]),
indexes=None,
edge_features=[{}, edge_features=[{},
{}], {}],
compacted_seeds=None, compacted_seeds=tensor([[0, 1],
compacted_node_pairs=(tensor([0, 1, 1, 1]), [2, 3],
tensor([2, 3, 3, 1])), [2, 3],
[2, 2],
[0, 4],
[0, 4],
[2, 2],
[2, 4],
[2, 0],
[2, 2],
[2, 2],
[2, 5]]),
compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_srcs=None,
compacted_negative_dsts=tensor([[4, 4], compacted_negative_dsts=None,
[1, 4], blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=3),
[0, 1],
[1, 5]]),
blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2),
Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2)], Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2)],
)""" )"""
), ),
str( str(
"""MiniBatch(seeds=None, """MiniBatch(seeds=tensor([[3, 3],
[4, 3],
[4, 4],
[0, 4],
[3, 1],
[3, 5],
[4, 2],
[4, 5],
[4, 4],
[4, 3],
[0, 1],
[0, 5]]),
seed_nodes=None, seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2, 3], dtype=torch.int32), sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 1, 2], dtype=torch.int32),
indices=tensor([4, 1, 0], dtype=torch.int32), indices=tensor([4, 0], dtype=torch.int32),
), ),
original_row_node_ids=tensor([3, 4, 0, 1, 5, 2]), original_row_node_ids=tensor([3, 4, 0, 1, 5, 2]),
original_edge_ids=None, original_edge_ids=None,
...@@ -135,12 +151,9 @@ def test_integration_link_prediction(): ...@@ -135,12 +151,9 @@ def test_integration_link_prediction():
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0, 1, 5, 2]), original_column_node_ids=tensor([3, 4, 0, 1, 5, 2]),
)], )],
positive_node_pairs=(tensor([0, 1, 1, 2]), positive_node_pairs=None,
tensor([0, 0, 1, 1])), node_pairs_with_labels=None,
node_pairs_with_labels=((tensor([0, 1, 1, 2, 0, 0, 1, 1, 1, 1, 2, 2]), tensor([0, 0, 1, 1, 3, 4, 5, 4, 1, 0, 3, 4])), node_pairs=None,
tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.])),
node_pairs=(tensor([3, 4, 4, 0]),
tensor([3, 3, 4, 4])),
node_features={'feat': tensor([[0.8672, 0.2276], node_features={'feat': tensor([[0.8672, 0.2276],
[0.5503, 0.8223], [0.5503, 0.8223],
[0.9634, 0.2294], [0.9634, 0.2294],
...@@ -148,37 +161,39 @@ def test_integration_link_prediction(): ...@@ -148,37 +161,39 @@ def test_integration_link_prediction():
[0.5160, 0.2486], [0.5160, 0.2486],
[0.2109, 0.1089]])}, [0.2109, 0.1089]])},
negative_srcs=None, negative_srcs=None,
negative_node_pairs=(tensor([[0, 0], negative_node_pairs=None,
[1, 1], negative_dsts=None,
[1, 1], labels=tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.]),
[2, 2]]),
tensor([[3, 4],
[5, 4],
[1, 0],
[3, 4]])),
negative_dsts=tensor([[1, 5],
[2, 5],
[4, 3],
[1, 5]]),
labels=None,
input_nodes=tensor([3, 4, 0, 1, 5, 2]), input_nodes=tensor([3, 4, 0, 1, 5, 2]),
indexes=None, indexes=tensor([0, 1, 2, 3, 0, 0, 1, 1, 2, 2, 3, 3]),
edge_features=[{}, edge_features=[{},
{}], {}],
compacted_seeds=None, compacted_seeds=tensor([[0, 0],
compacted_node_pairs=(tensor([0, 1, 1, 2]), [1, 0],
tensor([0, 0, 1, 1])), [1, 1],
[2, 1],
[0, 3],
[0, 4],
[1, 5],
[1, 4],
[1, 1],
[1, 0],
[2, 3],
[2, 4]]),
compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_srcs=None,
compacted_negative_dsts=tensor([[3, 4], compacted_negative_dsts=None,
[5, 4], blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2),
[1, 0],
[3, 4]]),
blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=3),
Block(num_src_nodes=6, num_dst_nodes=6, num_edges=3)], Block(num_src_nodes=6, num_dst_nodes=6, num_edges=3)],
)""" )"""
), ),
str( str(
"""MiniBatch(seeds=None, """MiniBatch(seeds=tensor([[5, 5],
[4, 5],
[5, 0],
[5, 4],
[4, 0],
[4, 1]]),
seed_nodes=None, seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2], dtype=torch.int32), sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2], dtype=torch.int32),
indices=tensor([1, 0], dtype=torch.int32), indices=tensor([1, 0], dtype=torch.int32),
...@@ -194,34 +209,30 @@ def test_integration_link_prediction(): ...@@ -194,34 +209,30 @@ def test_integration_link_prediction():
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 4, 0, 1]), original_column_node_ids=tensor([5, 4, 0, 1]),
)], )],
positive_node_pairs=(tensor([0, 1]), positive_node_pairs=None,
tensor([0, 0])), node_pairs_with_labels=None,
node_pairs_with_labels=((tensor([0, 1, 0, 0, 1, 1]), tensor([0, 0, 2, 1, 2, 3])), node_pairs=None,
tensor([1., 1., 0., 0., 0., 0.])),
node_pairs=(tensor([5, 4]),
tensor([5, 5])),
node_features={'feat': tensor([[0.5160, 0.2486], node_features={'feat': tensor([[0.5160, 0.2486],
[0.5503, 0.8223], [0.5503, 0.8223],
[0.9634, 0.2294], [0.9634, 0.2294],
[0.6172, 0.7865]])}, [0.6172, 0.7865]])},
negative_srcs=None, negative_srcs=None,
negative_node_pairs=(tensor([[0, 0], negative_node_pairs=None,
[1, 1]]), negative_dsts=None,
tensor([[2, 1], labels=tensor([1., 1., 0., 0., 0., 0.]),
[2, 3]])),
negative_dsts=tensor([[0, 4],
[0, 1]]),
labels=None,
input_nodes=tensor([5, 4, 0, 1]), input_nodes=tensor([5, 4, 0, 1]),
indexes=None, indexes=tensor([0, 1, 0, 0, 1, 1]),
edge_features=[{}, edge_features=[{},
{}], {}],
compacted_seeds=None, compacted_seeds=tensor([[0, 0],
compacted_node_pairs=(tensor([0, 1]), [1, 0],
tensor([0, 0])), [0, 2],
[0, 1],
[1, 2],
[1, 3]]),
compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_srcs=None,
compacted_negative_dsts=tensor([[2, 1], compacted_negative_dsts=None,
[2, 3]]),
blocks=[Block(num_src_nodes=4, num_dst_nodes=4, num_edges=2), blocks=[Block(num_src_nodes=4, num_dst_nodes=4, num_edges=2),
Block(num_src_nodes=4, num_dst_nodes=4, num_edges=2)], Block(num_src_nodes=4, num_dst_nodes=4, num_edges=2)],
)""" )"""
...@@ -285,43 +296,46 @@ def test_integration_node_classification(): ...@@ -285,43 +296,46 @@ def test_integration_node_classification():
) )
expected = [ expected = [
str( str(
"""MiniBatch(seeds=None, """MiniBatch(seeds=tensor([[5, 1],
[3, 2],
[3, 2],
[3, 3]]),
seed_nodes=None, seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4], dtype=torch.int32), sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4], dtype=torch.int32),
indices=tensor([4, 1, 0, 1], dtype=torch.int32), indices=tensor([4, 0, 2, 2], dtype=torch.int32),
), ),
original_row_node_ids=tensor([5, 3, 1, 2, 4]), original_row_node_ids=tensor([5, 1, 3, 2, 4]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2]), original_column_node_ids=tensor([5, 1, 3, 2]),
), ),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4], dtype=torch.int32), SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4], dtype=torch.int32),
indices=tensor([0, 1, 0, 1], dtype=torch.int32), indices=tensor([0, 0, 2, 2], dtype=torch.int32),
), ),
original_row_node_ids=tensor([5, 3, 1, 2]), original_row_node_ids=tensor([5, 1, 3, 2]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2]), original_column_node_ids=tensor([5, 1, 3, 2]),
)], )],
positive_node_pairs=(tensor([0, 1, 1, 1]), positive_node_pairs=None,
tensor([2, 3, 3, 1])),
node_pairs_with_labels=None, node_pairs_with_labels=None,
node_pairs=(tensor([5, 3, 3, 3]), node_pairs=None,
tensor([1, 2, 2, 3])),
node_features={'feat': tensor([[0.5160, 0.2486], node_features={'feat': tensor([[0.5160, 0.2486],
[0.8672, 0.2276],
[0.6172, 0.7865], [0.6172, 0.7865],
[0.8672, 0.2276],
[0.2109, 0.1089], [0.2109, 0.1089],
[0.5503, 0.8223]])}, [0.5503, 0.8223]])},
negative_srcs=None, negative_srcs=None,
negative_node_pairs=None, negative_node_pairs=None,
negative_dsts=None, negative_dsts=None,
labels=None, labels=None,
input_nodes=tensor([5, 3, 1, 2, 4]), input_nodes=tensor([5, 1, 3, 2, 4]),
indexes=None, indexes=None,
edge_features=[{}, edge_features=[{},
{}], {}],
compacted_seeds=None, compacted_seeds=tensor([[0, 1],
compacted_node_pairs=(tensor([0, 1, 1, 1]), [2, 3],
tensor([2, 3, 3, 1])), [2, 3],
[2, 2]]),
compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_srcs=None,
compacted_negative_dsts=None, compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=5, num_dst_nodes=4, num_edges=4), blocks=[Block(num_src_nodes=5, num_dst_nodes=4, num_edges=4),
...@@ -329,7 +343,10 @@ def test_integration_node_classification(): ...@@ -329,7 +343,10 @@ def test_integration_node_classification():
)""" )"""
), ),
str( str(
"""MiniBatch(seeds=None, """MiniBatch(seeds=tensor([[3, 3],
[4, 3],
[4, 4],
[0, 4]]),
seed_nodes=None, seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2], dtype=torch.int32), sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2], dtype=torch.int32),
indices=tensor([0, 2], dtype=torch.int32), indices=tensor([0, 2], dtype=torch.int32),
...@@ -345,11 +362,9 @@ def test_integration_node_classification(): ...@@ -345,11 +362,9 @@ def test_integration_node_classification():
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0]), original_column_node_ids=tensor([3, 4, 0]),
)], )],
positive_node_pairs=(tensor([0, 1, 1, 2]), positive_node_pairs=None,
tensor([0, 0, 1, 1])),
node_pairs_with_labels=None, node_pairs_with_labels=None,
node_pairs=(tensor([3, 4, 4, 0]), node_pairs=None,
tensor([3, 3, 4, 4])),
node_features={'feat': tensor([[0.8672, 0.2276], node_features={'feat': tensor([[0.8672, 0.2276],
[0.5503, 0.8223], [0.5503, 0.8223],
[0.9634, 0.2294]])}, [0.9634, 0.2294]])},
...@@ -361,9 +376,11 @@ def test_integration_node_classification(): ...@@ -361,9 +376,11 @@ def test_integration_node_classification():
indexes=None, indexes=None,
edge_features=[{}, edge_features=[{},
{}], {}],
compacted_seeds=None, compacted_seeds=tensor([[0, 0],
compacted_node_pairs=(tensor([0, 1, 1, 2]), [1, 0],
tensor([0, 0, 1, 1])), [1, 1],
[2, 1]]),
compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_srcs=None,
compacted_negative_dsts=None, compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=3, num_dst_nodes=3, num_edges=2), blocks=[Block(num_src_nodes=3, num_dst_nodes=3, num_edges=2),
...@@ -371,7 +388,8 @@ def test_integration_node_classification(): ...@@ -371,7 +388,8 @@ def test_integration_node_classification():
)""" )"""
), ),
str( str(
"""MiniBatch(seeds=None, """MiniBatch(seeds=tensor([[5, 5],
[4, 5]]),
seed_nodes=None, seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32), sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
indices=tensor([0, 2], dtype=torch.int32), indices=tensor([0, 2], dtype=torch.int32),
...@@ -387,11 +405,9 @@ def test_integration_node_classification(): ...@@ -387,11 +405,9 @@ def test_integration_node_classification():
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 4]), original_column_node_ids=tensor([5, 4]),
)], )],
positive_node_pairs=(tensor([0, 1]), positive_node_pairs=None,
tensor([0, 0])),
node_pairs_with_labels=None, node_pairs_with_labels=None,
node_pairs=(tensor([5, 4]), node_pairs=None,
tensor([5, 5])),
node_features={'feat': tensor([[0.5160, 0.2486], node_features={'feat': tensor([[0.5160, 0.2486],
[0.5503, 0.8223], [0.5503, 0.8223],
[0.9634, 0.2294]])}, [0.9634, 0.2294]])},
...@@ -403,9 +419,9 @@ def test_integration_node_classification(): ...@@ -403,9 +419,9 @@ def test_integration_node_classification():
indexes=None, indexes=None,
edge_features=[{}, edge_features=[{},
{}], {}],
compacted_seeds=None, compacted_seeds=tensor([[0, 0],
compacted_node_pairs=(tensor([0, 1]), [1, 0]]),
tensor([0, 0])), compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_srcs=None,
compacted_negative_dsts=None, compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=3, num_dst_nodes=2, num_edges=2), blocks=[Block(num_src_nodes=3, num_dst_nodes=2, num_edges=2),
......
...@@ -49,20 +49,20 @@ def test_ItemSampler_minibatcher(): ...@@ -49,20 +49,20 @@ def test_ItemSampler_minibatcher():
item_sampler = gb.ItemSampler(item_set, batch_size=4) item_sampler = gb.ItemSampler(item_set, batch_size=4)
minibatch = next(iter(item_sampler)) minibatch = next(iter(item_sampler))
assert isinstance(minibatch, gb.MiniBatch) assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None assert minibatch.seeds is not None
assert len(minibatch.seed_nodes) == 4 assert len(minibatch.seeds) == 4
# Customized minibatcher is used if specified. # Customized minibatcher is used if specified.
def minibatcher(batch, names): def minibatcher(batch, names):
return gb.MiniBatch(seed_nodes=batch) return gb.MiniBatch(seeds=batch)
item_sampler = gb.ItemSampler( item_sampler = gb.ItemSampler(
item_set, batch_size=4, minibatcher=minibatcher item_set, batch_size=4, minibatcher=minibatcher
) )
minibatch = next(iter(item_sampler)) minibatch = next(iter(item_sampler))
assert isinstance(minibatch, gb.MiniBatch) assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None assert minibatch.seeds is not None
assert len(minibatch.seed_nodes) == 4 assert len(minibatch.seeds) == 4
@pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("batch_size", [1, 4])
...@@ -83,17 +83,17 @@ def test_ItemSet_Iterable_Only(batch_size, shuffle, drop_last): ...@@ -83,17 +83,17 @@ def test_ItemSet_Iterable_Only(batch_size, shuffle, drop_last):
minibatch_ids = [] minibatch_ids = []
for i, minibatch in enumerate(item_sampler): for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch) assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None assert minibatch.seeds is not None
assert minibatch.labels is None assert minibatch.labels is None
is_last = (i + 1) * batch_size >= num_ids is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0: if not is_last or num_ids % batch_size == 0:
assert len(minibatch.seed_nodes) == batch_size assert len(minibatch.seeds) == batch_size
else: else:
if not drop_last: if not drop_last:
assert len(minibatch.seed_nodes) == num_ids % batch_size assert len(minibatch.seeds) == num_ids % batch_size
else: else:
assert False assert False
minibatch_ids.append(minibatch.seed_nodes) minibatch_ids.append(minibatch.seeds)
minibatch_ids = torch.cat(minibatch_ids) minibatch_ids = torch.cat(minibatch_ids)
assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle
...@@ -111,17 +111,17 @@ def test_ItemSet_integer(batch_size, shuffle, drop_last): ...@@ -111,17 +111,17 @@ def test_ItemSet_integer(batch_size, shuffle, drop_last):
minibatch_ids = [] minibatch_ids = []
for i, minibatch in enumerate(item_sampler): for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch) assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None assert minibatch.seeds is not None
assert minibatch.labels is None assert minibatch.labels is None
is_last = (i + 1) * batch_size >= num_ids is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0: if not is_last or num_ids % batch_size == 0:
assert len(minibatch.seed_nodes) == batch_size assert len(minibatch.seeds) == batch_size
else: else:
if not drop_last: if not drop_last:
assert len(minibatch.seed_nodes) == num_ids % batch_size assert len(minibatch.seeds) == num_ids % batch_size
else: else:
assert False assert False
minibatch_ids.append(minibatch.seed_nodes) minibatch_ids.append(minibatch.seeds)
minibatch_ids = torch.cat(minibatch_ids) minibatch_ids = torch.cat(minibatch_ids)
assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle
...@@ -140,17 +140,17 @@ def test_ItemSet_seed_nodes(batch_size, shuffle, drop_last): ...@@ -140,17 +140,17 @@ def test_ItemSet_seed_nodes(batch_size, shuffle, drop_last):
minibatch_ids = [] minibatch_ids = []
for i, minibatch in enumerate(item_sampler): for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch) assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None assert minibatch.seeds is not None
assert minibatch.labels is None assert minibatch.labels is None
is_last = (i + 1) * batch_size >= num_ids is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0: if not is_last or num_ids % batch_size == 0:
assert len(minibatch.seed_nodes) == batch_size assert len(minibatch.seeds) == batch_size
else: else:
if not drop_last: if not drop_last:
assert len(minibatch.seed_nodes) == num_ids % batch_size assert len(minibatch.seeds) == num_ids % batch_size
else: else:
assert False assert False
minibatch_ids.append(minibatch.seed_nodes) minibatch_ids.append(minibatch.seeds)
minibatch_ids = torch.cat(minibatch_ids) minibatch_ids = torch.cat(minibatch_ids)
assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle
...@@ -171,18 +171,18 @@ def test_ItemSet_seed_nodes_labels(batch_size, shuffle, drop_last): ...@@ -171,18 +171,18 @@ def test_ItemSet_seed_nodes_labels(batch_size, shuffle, drop_last):
minibatch_labels = [] minibatch_labels = []
for i, minibatch in enumerate(item_sampler): for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch) assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None assert minibatch.seeds is not None
assert minibatch.labels is not None assert minibatch.labels is not None
assert len(minibatch.seed_nodes) == len(minibatch.labels) assert len(minibatch.seeds) == len(minibatch.labels)
is_last = (i + 1) * batch_size >= num_ids is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0: if not is_last or num_ids % batch_size == 0:
assert len(minibatch.seed_nodes) == batch_size assert len(minibatch.seeds) == batch_size
else: else:
if not drop_last: if not drop_last:
assert len(minibatch.seed_nodes) == num_ids % batch_size assert len(minibatch.seeds) == num_ids % batch_size
else: else:
assert False assert False
minibatch_ids.append(minibatch.seed_nodes) minibatch_ids.append(minibatch.seeds)
minibatch_labels.append(minibatch.labels) minibatch_labels.append(minibatch.labels)
minibatch_ids = torch.cat(minibatch_ids) minibatch_ids = torch.cat(minibatch_ids)
minibatch_labels = torch.cat(minibatch_labels) minibatch_labels = torch.cat(minibatch_labels)
...@@ -254,10 +254,10 @@ def test_ItemSet_node_pairs(batch_size, shuffle, drop_last): ...@@ -254,10 +254,10 @@ def test_ItemSet_node_pairs(batch_size, shuffle, drop_last):
src_ids = [] src_ids = []
dst_ids = [] dst_ids = []
for i, minibatch in enumerate(item_sampler): for i, minibatch in enumerate(item_sampler):
assert minibatch.node_pairs is not None assert minibatch.seeds is not None
assert isinstance(minibatch.node_pairs, tuple) assert isinstance(minibatch.seeds, torch.Tensor)
assert minibatch.labels is None assert minibatch.labels is None
src, dst = minibatch.node_pairs src, dst = minibatch.seeds.T
is_last = (i + 1) * batch_size >= num_ids is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0: if not is_last or num_ids % batch_size == 0:
expected_batch_size = batch_size expected_batch_size = batch_size
...@@ -295,10 +295,10 @@ def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last): ...@@ -295,10 +295,10 @@ def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
dst_ids = [] dst_ids = []
labels = [] labels = []
for i, minibatch in enumerate(item_sampler): for i, minibatch in enumerate(item_sampler):
assert minibatch.node_pairs is not None assert minibatch.seeds is not None
assert isinstance(minibatch.node_pairs, tuple) assert isinstance(minibatch.seeds, torch.Tensor)
assert minibatch.labels is not None assert minibatch.labels is not None
src, dst = minibatch.node_pairs src, dst = minibatch.seeds.T
label = minibatch.labels label = minibatch.labels
assert len(src) == len(dst) assert len(src) == len(dst)
assert len(src) == len(label) assert len(src) == len(label)
...@@ -349,11 +349,13 @@ def test_ItemSet_node_pairs_negative_dsts(batch_size, shuffle, drop_last): ...@@ -349,11 +349,13 @@ def test_ItemSet_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
dst_ids = [] dst_ids = []
negs_ids = [] negs_ids = []
for i, minibatch in enumerate(item_sampler): for i, minibatch in enumerate(item_sampler):
assert minibatch.node_pairs is not None assert minibatch.seeds is not None
assert isinstance(minibatch.node_pairs, tuple) assert isinstance(minibatch.seeds, torch.Tensor)
assert minibatch.negative_dsts is not None assert minibatch.labels is not None
src, dst = minibatch.node_pairs assert minibatch.indexes is not None
negs = minibatch.negative_dsts src, dst = minibatch.seeds.T
negs_src = src[~minibatch.labels.to(bool)]
negs_dst = dst[~minibatch.labels.to(bool)]
is_last = (i + 1) * batch_size >= num_ids is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0: if not is_last or num_ids % batch_size == 0:
expected_batch_size = batch_size expected_batch_size = batch_size
...@@ -362,25 +364,32 @@ def test_ItemSet_node_pairs_negative_dsts(batch_size, shuffle, drop_last): ...@@ -362,25 +364,32 @@ def test_ItemSet_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
expected_batch_size = num_ids % batch_size expected_batch_size = num_ids % batch_size
else: else:
assert False assert False
assert len(src) == expected_batch_size assert len(src) == expected_batch_size * 3
assert len(dst) == expected_batch_size assert len(dst) == expected_batch_size * 3
assert negs.dim() == 2 assert negs_src.dim() == 1
assert negs.shape[0] == expected_batch_size assert negs_dst.dim() == 1
assert negs.shape[1] == num_negs assert len(negs_src) == expected_batch_size * 2
assert len(negs_dst) == expected_batch_size * 2
expected_indexes = torch.arange(expected_batch_size)
expected_indexes = torch.cat(
(expected_indexes, expected_indexes.repeat_interleave(2))
)
assert torch.equal(minibatch.indexes, expected_indexes)
# Verify node pairs and negative destinations. # Verify node pairs and negative destinations.
assert torch.equal(src + 1, dst) assert torch.equal(
assert torch.equal(negs[:, 0] + 1, negs[:, 1]) src[minibatch.labels.to(bool)] + 1, dst[minibatch.labels.to(bool)]
)
assert torch.equal((negs_dst - 2 * num_ids) // 2 * 2, negs_src)
# Archive batch. # Archive batch.
src_ids.append(src) src_ids.append(src)
dst_ids.append(dst) dst_ids.append(dst)
negs_ids.append(negs) negs_ids.append(negs_dst)
src_ids = torch.cat(src_ids) src_ids = torch.cat(src_ids)
dst_ids = torch.cat(dst_ids) dst_ids = torch.cat(dst_ids)
negs_ids = torch.cat(negs_ids) negs_ids = torch.cat(negs_ids)
assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle
assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle
assert torch.all(negs_ids[:-1, 0] <= negs_ids[1:, 0]) is not shuffle assert torch.all(negs_ids[:-1] <= negs_ids[1:]) is not shuffle
assert torch.all(negs_ids[:-1, 1] <= negs_ids[1:, 1]) is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("batch_size", [1, 4])
...@@ -472,7 +481,7 @@ def test_append_with_other_datapipes(): ...@@ -472,7 +481,7 @@ def test_append_with_other_datapipes():
data_pipe = data_pipe.enumerate() data_pipe = data_pipe.enumerate()
for i, (idx, data) in enumerate(data_pipe): for i, (idx, data) in enumerate(data_pipe):
assert i == idx assert i == idx
assert len(data.seed_nodes) == batch_size assert len(data.seeds) == batch_size
@pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("batch_size", [1, 4])
...@@ -510,9 +519,9 @@ def test_ItemSetDict_iterable_only(batch_size, shuffle, drop_last): ...@@ -510,9 +519,9 @@ def test_ItemSetDict_iterable_only(batch_size, shuffle, drop_last):
else: else:
assert False assert False
assert isinstance(minibatch, gb.MiniBatch) assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None assert minibatch.seeds is not None
ids = [] ids = []
for _, v in minibatch.seed_nodes.items(): for _, v in minibatch.seeds.items():
ids.append(v) ids.append(v)
ids = torch.cat(ids) ids = torch.cat(ids)
assert len(ids) == expected_batch_size assert len(ids) == expected_batch_size
...@@ -549,9 +558,9 @@ def test_ItemSetDict_seed_nodes(batch_size, shuffle, drop_last): ...@@ -549,9 +558,9 @@ def test_ItemSetDict_seed_nodes(batch_size, shuffle, drop_last):
else: else:
assert False assert False
assert isinstance(minibatch, gb.MiniBatch) assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None assert minibatch.seeds is not None
ids = [] ids = []
for _, v in minibatch.seed_nodes.items(): for _, v in minibatch.seeds.items():
ids.append(v) ids.append(v)
ids = torch.cat(ids) ids = torch.cat(ids)
assert len(ids) == expected_batch_size assert len(ids) == expected_batch_size
...@@ -587,7 +596,7 @@ def test_ItemSetDict_seed_nodes_labels(batch_size, shuffle, drop_last): ...@@ -587,7 +596,7 @@ def test_ItemSetDict_seed_nodes_labels(batch_size, shuffle, drop_last):
minibatch_labels = [] minibatch_labels = []
for i, minibatch in enumerate(item_sampler): for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch) assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None assert minibatch.seeds is not None
assert minibatch.labels is not None assert minibatch.labels is not None
is_last = (i + 1) * batch_size >= num_ids is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0: if not is_last or num_ids % batch_size == 0:
...@@ -598,7 +607,7 @@ def test_ItemSetDict_seed_nodes_labels(batch_size, shuffle, drop_last): ...@@ -598,7 +607,7 @@ def test_ItemSetDict_seed_nodes_labels(batch_size, shuffle, drop_last):
else: else:
assert False assert False
ids = [] ids = []
for _, v in minibatch.seed_nodes.items(): for _, v in minibatch.seeds.items():
ids.append(v) ids.append(v)
ids = torch.cat(ids) ids = torch.cat(ids)
assert len(ids) == expected_batch_size assert len(ids) == expected_batch_size
...@@ -638,7 +647,7 @@ def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last): ...@@ -638,7 +647,7 @@ def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last):
dst_ids = [] dst_ids = []
for i, minibatch in enumerate(item_sampler): for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch) assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.node_pairs is not None assert minibatch.seeds is not None
assert minibatch.labels is None assert minibatch.labels is None
is_last = (i + 1) * batch_size >= total_pairs is_last = (i + 1) * batch_size >= total_pairs
if not is_last or total_pairs % batch_size == 0: if not is_last or total_pairs % batch_size == 0:
...@@ -650,10 +659,10 @@ def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last): ...@@ -650,10 +659,10 @@ def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last):
assert False assert False
src = [] src = []
dst = [] dst = []
for _, (node_pairs) in minibatch.node_pairs.items(): for _, (seeds) in minibatch.seeds.items():
assert isinstance(node_pairs, tuple) assert isinstance(seeds, torch.Tensor)
src.append(node_pairs[0]) src.append(seeds[:, 0])
dst.append(node_pairs[1]) dst.append(seeds[:, 1])
src = torch.cat(src) src = torch.cat(src)
dst = torch.cat(dst) dst = torch.cat(dst)
assert len(src) == expected_batch_size assert len(src) == expected_batch_size
...@@ -696,8 +705,9 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last): ...@@ -696,8 +705,9 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
labels = [] labels = []
for i, minibatch in enumerate(item_sampler): for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch) assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.node_pairs is not None assert minibatch.seeds is not None
assert minibatch.labels is not None assert minibatch.labels is not None
assert minibatch.negative_dsts is None
is_last = (i + 1) * batch_size >= total_ids is_last = (i + 1) * batch_size >= total_ids
if not is_last or total_ids % batch_size == 0: if not is_last or total_ids % batch_size == 0:
expected_batch_size = batch_size expected_batch_size = batch_size
...@@ -709,10 +719,10 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last): ...@@ -709,10 +719,10 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
src = [] src = []
dst = [] dst = []
label = [] label = []
for _, node_pairs in minibatch.node_pairs.items(): for _, seeds in minibatch.seeds.items():
assert isinstance(node_pairs, tuple) assert isinstance(seeds, torch.Tensor)
src.append(node_pairs[0]) src.append(seeds[:, 0])
dst.append(node_pairs[1]) dst.append(seeds[:, 1])
for _, v_label in minibatch.labels.items(): for _, v_label in minibatch.labels.items():
label.append(v_label) label.append(v_label)
src = torch.cat(src) src = torch.cat(src)
...@@ -769,8 +779,9 @@ def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last): ...@@ -769,8 +779,9 @@ def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
negs_ids = [] negs_ids = []
for i, minibatch in enumerate(item_sampler): for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch) assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.node_pairs is not None assert minibatch.seeds is not None
assert minibatch.negative_dsts is not None assert minibatch.labels is not None
assert minibatch.negative_dsts is None
is_last = (i + 1) * batch_size >= total_ids is_last = (i + 1) * batch_size >= total_ids
if not is_last or total_ids % batch_size == 0: if not is_last or total_ids % batch_size == 0:
expected_batch_size = batch_size expected_batch_size = batch_size
...@@ -781,33 +792,37 @@ def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last): ...@@ -781,33 +792,37 @@ def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
assert False assert False
src = [] src = []
dst = [] dst = []
negs = [] negs_src = []
for _, node_pairs in minibatch.node_pairs.items(): negs_dst = []
assert isinstance(node_pairs, tuple) for etype, seeds in minibatch.seeds.items():
src.append(node_pairs[0]) assert isinstance(seeds, torch.Tensor)
dst.append(node_pairs[1]) src_etype = seeds[:, 0]
for _, v_negs in minibatch.negative_dsts.items(): dst_etype = seeds[:, 1]
negs.append(v_negs) src.append(src_etype[minibatch.labels[etype].to(bool)])
dst.append(dst_etype[minibatch.labels[etype].to(bool)])
negs_src.append(src_etype[~minibatch.labels[etype].to(bool)])
negs_dst.append(dst_etype[~minibatch.labels[etype].to(bool)])
src = torch.cat(src) src = torch.cat(src)
dst = torch.cat(dst) dst = torch.cat(dst)
negs = torch.cat(negs) negs_src = torch.cat(negs_src)
negs_dst = torch.cat(negs_dst)
assert len(src) == expected_batch_size assert len(src) == expected_batch_size
assert len(dst) == expected_batch_size assert len(dst) == expected_batch_size
assert len(negs) == expected_batch_size assert len(negs_src) == expected_batch_size * 2
assert len(negs_dst) == expected_batch_size * 2
src_ids.append(src) src_ids.append(src)
dst_ids.append(dst) dst_ids.append(dst)
negs_ids.append(negs) negs_ids.append(negs_dst)
assert negs.dim() == 2 assert negs_src.dim() == 1
assert negs.shape[0] == expected_batch_size assert negs_dst.dim() == 1
assert negs.shape[1] == num_negs
assert torch.equal(src + 1, dst) assert torch.equal(src + 1, dst)
assert torch.equal(negs[:, 0] + 1, negs[:, 1]) assert torch.equal(negs_src, (negs_dst - num_ids * 4) // 2 * 2)
src_ids = torch.cat(src_ids) src_ids = torch.cat(src_ids)
dst_ids = torch.cat(dst_ids) dst_ids = torch.cat(dst_ids)
negs_ids = torch.cat(negs_ids) negs_ids = torch.cat(negs_ids)
assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle
assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle
assert torch.all(negs_ids[:-1] <= negs_ids[1:]) is not shuffle assert torch.all(negs_ids <= negs_ids) is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("batch_size", [1, 4])
...@@ -961,10 +976,10 @@ def distributed_item_sampler_subprocess( ...@@ -961,10 +976,10 @@ def distributed_item_sampler_subprocess(
sampled_count = torch.zeros(num_ids, dtype=torch.int32) sampled_count = torch.zeros(num_ids, dtype=torch.int32)
for i in data_loader: for i in data_loader:
# Count how many times each item is sampled. # Count how many times each item is sampled.
sampled_count[i.seed_nodes] += 1 sampled_count[i.seeds] += 1
if drop_last: if drop_last:
assert i.seed_nodes.size(0) == batch_size assert i.seeds.size(0) == batch_size
num_items += i.seed_nodes.size(0) num_items += i.seeds.size(0)
num_batches = len(list(item_sampler)) num_batches = len(list(item_sampler))
if drop_uneven_inputs: if drop_uneven_inputs:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment