Note: the repository "git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist at revision "0785dba4df988119955b5380877e50d134416101".
Unverified commit b743cdef, authored by yxy235 and committed via GitHub.
Browse files

[GraphBolt] Hack for original `seed_nodes` and `node_pairs`. (#7248)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 3df6e301
......@@ -62,6 +62,37 @@ def minibatcher_default(batch, names):
else:
init_data = {name: item for item, name in zip(batch, names)}
minibatch = MiniBatch()
# TODO(#7254): Hacks for original `seed_nodes` and `node_pairs`, which need
# to be cleaned up later.
if "node_pairs" in names:
pos_seeds = init_data["node_pairs"]
# Build negative graph.
if "negative_srcs" in names and "negative_dsts" in names:
neg_srcs = init_data["negative_srcs"]
neg_dsts = init_data["negative_dsts"]
(
init_data["seeds"],
init_data["labels"],
init_data["indexes"],
) = _construct_seeds(
pos_seeds, neg_srcs=neg_srcs, neg_dsts=neg_dsts
)
elif "negative_srcs" in names:
neg_srcs = init_data["negative_srcs"]
(
init_data["seeds"],
init_data["labels"],
init_data["indexes"],
) = _construct_seeds(pos_seeds, neg_srcs=neg_srcs)
elif "negative_dsts" in names:
neg_dsts = init_data["negative_dsts"]
(
init_data["seeds"],
init_data["labels"],
init_data["indexes"],
) = _construct_seeds(pos_seeds, neg_dsts=neg_dsts)
else:
init_data["seeds"] = pos_seeds
for name, item in init_data.items():
if not hasattr(minibatch, name):
dgl_warning(
......@@ -69,13 +100,12 @@ def minibatcher_default(batch, names):
"`MiniBatch`. You probably need to provide a customized "
"`MiniBatcher`."
)
if name == "node_pairs":
# `node_pairs` is passed as a tensor in shape of `(N, 2)` and
# should be converted to a tuple of `(src, dst)`.
if isinstance(item, Mapping):
item = {key: (item[key][:, 0], item[key][:, 1]) for key in item}
else:
item = (item[:, 0], item[:, 1])
# TODO(#7254): Hacks for original `seed_nodes` and `node_pairs`, which
# need to be cleaned up later.
if name == "seed_nodes":
name = "seeds"
if name in ("node_pairs", "negative_srcs", "negative_dsts"):
continue
setattr(minibatch, name, item)
return minibatch
......@@ -744,3 +774,80 @@ class DistributedItemSampler(ItemSampler):
)
self._world_size = dist.get_world_size()
self._rank = dist.get_rank()
def _construct_seeds(pos_seeds, neg_srcs=None, neg_dsts=None):
# For homogeneous graph.
if isinstance(pos_seeds, torch.Tensor):
negative_ratio = neg_srcs.size(1) if neg_srcs else neg_dsts.size(1)
neg_srcs = (
neg_srcs
if neg_srcs is not None
else pos_seeds[:, 0].repeat_interleave(negative_ratio)
).view(-1)
neg_dsts = (
neg_dsts
if neg_dsts is not None
else pos_seeds[:, 1].repeat_interleave(negative_ratio)
).view(-1)
neg_seeds = torch.cat((neg_srcs, neg_dsts)).view(2, -1).T
seeds = torch.cat((pos_seeds, neg_seeds))
pos_seeds_num = pos_seeds.size(0)
labels = torch.empty(seeds.size(0), device=pos_seeds.device)
labels[:pos_seeds_num] = 1
labels[pos_seeds_num:] = 0
pos_indexes = torch.arange(
0,
pos_seeds_num,
device=pos_seeds.device,
)
neg_indexes = pos_indexes.repeat_interleave(negative_ratio)
indexes = torch.cat((pos_indexes, neg_indexes))
# For heterogeneous graph.
else:
negative_ratio = (
list(neg_srcs.values())[0].size(1)
if neg_srcs
else list(neg_dsts.values())[0].size(1)
)
seeds = {}
labels = {}
indexes = {}
for etype in pos_seeds:
neg_src = (
neg_srcs[etype]
if neg_srcs is not None
else pos_seeds[etype][:, 0].repeat_interleave(negative_ratio)
).view(-1)
neg_dst = (
neg_dsts[etype]
if neg_dsts is not None
else pos_seeds[etype][:, 1].repeat_interleave(negative_ratio)
).view(-1)
seeds[etype] = torch.cat(
(
pos_seeds[etype],
torch.cat(
(
neg_src,
neg_dst,
)
)
.view(2, -1)
.T,
)
)
pos_seeds_num = pos_seeds[etype].size(0)
labels[etype] = torch.empty(
seeds[etype].size(0), device=pos_seeds[etype].device
)
labels[etype][:pos_seeds_num] = 1
labels[etype][pos_seeds_num:] = 0
pos_indexes = torch.arange(
0,
pos_seeds_num,
device=pos_seeds[etype].device,
)
neg_indexes = pos_indexes.repeat_interleave(negative_ratio)
indexes[etype] = torch.cat((pos_indexes, neg_indexes))
return seeds, labels, indexes
......@@ -99,14 +99,14 @@ def test_InSubgraphSampler_homo():
return _indices
mn = next(it)
assert torch.equal(mn.seed_nodes, torch.LongTensor([0]).to(F.ctx()))
assert torch.equal(mn.seeds, torch.LongTensor([0]).to(F.ctx()))
assert torch.equal(
mn.sampled_subgraphs[0].sampled_csc.indptr,
torch.tensor([0, 3]).to(F.ctx()),
)
mn = next(it)
assert torch.equal(mn.seed_nodes, torch.LongTensor([5]).to(F.ctx()))
assert torch.equal(mn.seeds, torch.LongTensor([5]).to(F.ctx()))
assert torch.equal(
mn.sampled_subgraphs[0].sampled_csc.indptr,
torch.tensor([0, 2]).to(F.ctx()),
......@@ -114,7 +114,7 @@ def test_InSubgraphSampler_homo():
assert torch.equal(original_indices(mn), torch.tensor([1, 4]).to(F.ctx()))
mn = next(it)
assert torch.equal(mn.seed_nodes, torch.LongTensor([3]).to(F.ctx()))
assert torch.equal(mn.seeds, torch.LongTensor([3]).to(F.ctx()))
assert torch.equal(
mn.sampled_subgraphs[0].sampled_csc.indptr,
torch.tensor([0, 2]).to(F.ctx()),
......@@ -176,9 +176,7 @@ def test_InSubgraphSampler_hetero():
it = iter(in_subgraph_sampler)
mn = next(it)
assert torch.equal(
mn.seed_nodes["N0"], torch.LongTensor([1, 0]).to(F.ctx())
)
assert torch.equal(mn.seeds["N0"], torch.LongTensor([1, 0]).to(F.ctx()))
expected_sampled_csc = {
"N0:R0:N0": gb.CSCFormatBase(
indptr=torch.LongTensor([0, 1, 3]),
......@@ -203,7 +201,7 @@ def test_InSubgraphSampler_hetero():
)
mn = next(it)
assert mn.seed_nodes == {
assert mn.seeds == {
"N0": torch.LongTensor([2]).to(F.ctx()),
"N1": torch.LongTensor([0]).to(F.ctx()),
}
......@@ -230,9 +228,7 @@ def test_InSubgraphSampler_hetero():
)
mn = next(it)
assert torch.equal(
mn.seed_nodes["N1"], torch.LongTensor([2, 1]).to(F.ctx())
)
assert torch.equal(mn.seeds["N1"], torch.LongTensor([2, 1]).to(F.ctx()))
expected_sampled_csc = {
"N0:R0:N0": gb.CSCFormatBase(
indptr=torch.LongTensor([0]), indices=torch.LongTensor([])
......
......@@ -95,9 +95,10 @@ def test_UniformNegativeSampler_node_pairs_invoke():
def _verify(negative_sampler):
for data in negative_sampler:
# Assertion
assert data.negative_srcs is None
assert data.negative_dsts.size(0) == batch_size
assert data.negative_dsts.size(1) == negative_ratio
seeds_len = batch_size + batch_size * negative_ratio
assert data.seeds.size(0) == seeds_len
assert data.labels.size(0) == seeds_len
assert data.indexes.size(0) == seeds_len
# Invoke UniformNegativeSampler via class constructor.
negative_sampler = gb.UniformNegativeSampler(
......@@ -137,14 +138,30 @@ def test_Uniform_NegativeSampler_node_pairs(negative_ratio):
)
# Perform Negative sampling.
for data in negative_sampler:
pos_src, pos_dst = data.node_pairs
neg_src, neg_dst = data.negative_srcs, data.negative_dsts
expected_labels = torch.empty(
batch_size * (negative_ratio + 1), device=F.ctx()
)
expected_labels[:batch_size] = 1
expected_labels[batch_size:] = 0
expected_indexes = torch.arange(batch_size, device=F.ctx())
expected_indexes = torch.cat(
(
expected_indexes,
expected_indexes.repeat_interleave(negative_ratio),
)
)
expected_neg_src = data.seeds[:batch_size][:, 0].repeat_interleave(
negative_ratio
)
# Assertion
assert len(pos_src) == batch_size
assert len(pos_dst) == batch_size
assert len(neg_dst) == batch_size
assert neg_src is None
assert neg_dst.numel() == batch_size * negative_ratio
assert data.negative_srcs is None
assert data.negative_dsts is None
assert data.labels is not None
assert data.indexes is not None
assert data.seeds.size(0) == batch_size * (negative_ratio + 1)
assert torch.equal(data.labels, expected_labels)
assert torch.equal(data.indexes, expected_indexes)
assert torch.equal(data.seeds[batch_size:][:, 0], expected_neg_src)
@pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20])
......
......@@ -15,18 +15,18 @@ from . import gb_test_utils
@unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
def test_CopyTo():
item_sampler = gb.ItemSampler(
gb.ItemSet(torch.arange(20), names="seed_nodes"), 4
gb.ItemSet(torch.arange(20), names="seeds"), 4
)
# Invoke CopyTo via class constructor.
dp = gb.CopyTo(item_sampler, "cuda")
for data in dp:
assert data.seed_nodes.device.type == "cuda"
assert data.seeds.device.type == "cuda"
# Invoke CopyTo via functional form.
dp = item_sampler.copy_to("cuda")
for data in dp:
assert data.seed_nodes.device.type == "cuda"
assert data.seeds.device.type == "cuda"
@pytest.mark.parametrize(
......@@ -37,7 +37,6 @@ def test_CopyTo():
"link_prediction",
"edge_classification",
"extra_attrs",
"other",
],
)
@unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
......@@ -63,11 +62,6 @@ def test_CopyToWithMiniBatches_original(task):
(torch.arange(2 * N).reshape(-1, 2), torch.arange(N)),
names=("node_pairs", "labels"),
)
else:
itemset = gb.ItemSet(
(torch.arange(2 * N).reshape(-1, 2), torch.arange(N)),
names=("node_pairs", "seed_nodes"),
)
graph = gb_test_utils.rand_csc_graph(100, 0.15, bidirection_edge=True)
features = {}
......@@ -96,38 +90,25 @@ def test_CopyToWithMiniBatches_original(task):
"sampled_subgraphs",
"labels",
"blocks",
"seeds",
]
elif task == "node_inference":
copied_attrs = [
"seed_nodes",
"seeds",
"sampled_subgraphs",
"blocks",
"labels",
]
elif task == "link_prediction":
elif task == "link_prediction" or task == "edge_classification":
copied_attrs = [
"compacted_node_pairs",
"node_features",
"edge_features",
"labels",
"compacted_seeds",
"sampled_subgraphs",
"compacted_negative_srcs",
"compacted_negative_dsts",
"blocks",
"positive_node_pairs",
"negative_node_pairs",
"node_pairs_with_labels",
]
elif task == "edge_classification":
copied_attrs = [
"compacted_node_pairs",
"indexes",
"node_features",
"edge_features",
"sampled_subgraphs",
"labels",
"blocks",
"positive_node_pairs",
"negative_node_pairs",
"node_pairs_with_labels",
"seeds",
]
elif task == "extra_attrs":
copied_attrs = [
......@@ -137,6 +118,7 @@ def test_CopyToWithMiniBatches_original(task):
"labels",
"blocks",
"seed_nodes",
"seeds",
]
def test_data_device(datapipe):
......
......@@ -78,7 +78,7 @@ def test_FeatureFetcher_with_edges_homo():
)
def add_node_and_edge_ids(minibatch):
seeds = minibatch.seed_nodes
seeds = minibatch.seeds
subgraphs = []
for _ in range(3):
sampled_csc = gb.CSCFormatBase(
......@@ -172,7 +172,7 @@ def test_FeatureFetcher_with_edges_hetero():
b = torch.tensor([[random.randint(0, 10)] for _ in range(50)])
def add_node_and_edge_ids(minibatch):
seeds = minibatch.seed_nodes
seeds = minibatch.seeds
subgraphs = []
original_edge_ids = {
"n1:e1:n2": torch.randint(0, 50, (10,)),
......
......@@ -60,69 +60,85 @@ def test_integration_link_prediction():
)
expected = [
str(
"""MiniBatch(seeds=None,
"""MiniBatch(seeds=tensor([[5, 1],
[3, 2],
[3, 2],
[3, 3],
[5, 0],
[5, 0],
[3, 3],
[3, 0],
[3, 5],
[3, 3],
[3, 3],
[3, 4]]),
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2], dtype=torch.int32),
indices=tensor([0, 4], dtype=torch.int32),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 2, 2, 2, 3], dtype=torch.int32),
indices=tensor([0, 5, 4], dtype=torch.int32),
),
original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]),
original_row_node_ids=tensor([5, 1, 3, 2, 0, 4]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2, 0, 4]),
original_column_node_ids=tensor([5, 1, 3, 2, 0, 4]),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2], dtype=torch.int32),
indices=tensor([5, 4], dtype=torch.int32),
),
original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]),
original_row_node_ids=tensor([5, 1, 3, 2, 0, 4]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2, 0, 4]),
original_column_node_ids=tensor([5, 1, 3, 2, 0, 4]),
)],
positive_node_pairs=(tensor([0, 1, 1, 1]),
tensor([2, 3, 3, 1])),
node_pairs_with_labels=((tensor([0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1]), tensor([2, 3, 3, 1, 4, 4, 1, 4, 0, 1, 1, 5])),
tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.])),
node_pairs=(tensor([5, 3, 3, 3]),
tensor([1, 2, 2, 3])),
positive_node_pairs=None,
node_pairs_with_labels=None,
node_pairs=None,
node_features={'feat': tensor([[0.5160, 0.2486],
[0.8672, 0.2276],
[0.6172, 0.7865],
[0.8672, 0.2276],
[0.2109, 0.1089],
[0.9634, 0.2294],
[0.5503, 0.8223]])},
negative_srcs=None,
negative_node_pairs=(tensor([[0, 0],
[1, 1],
[1, 1],
[1, 1]]),
tensor([[4, 4],
[1, 4],
[0, 1],
[1, 5]])),
negative_dsts=tensor([[0, 0],
[3, 0],
[5, 3],
[3, 4]]),
labels=None,
input_nodes=tensor([5, 3, 1, 2, 0, 4]),
indexes=None,
negative_node_pairs=None,
negative_dsts=None,
labels=tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.]),
input_nodes=tensor([5, 1, 3, 2, 0, 4]),
indexes=tensor([0, 1, 2, 3, 0, 0, 1, 1, 2, 2, 3, 3]),
edge_features=[{},
{}],
compacted_seeds=None,
compacted_node_pairs=(tensor([0, 1, 1, 1]),
tensor([2, 3, 3, 1])),
compacted_seeds=tensor([[0, 1],
[2, 3],
[2, 3],
[2, 2],
[0, 4],
[0, 4],
[2, 2],
[2, 4],
[2, 0],
[2, 2],
[2, 2],
[2, 5]]),
compacted_node_pairs=None,
compacted_negative_srcs=None,
compacted_negative_dsts=tensor([[4, 4],
[1, 4],
[0, 1],
[1, 5]]),
blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2),
compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=3),
Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2)],
)"""
),
str(
"""MiniBatch(seeds=None,
"""MiniBatch(seeds=tensor([[3, 3],
[4, 3],
[4, 4],
[0, 4],
[3, 1],
[3, 5],
[4, 2],
[4, 5],
[4, 4],
[4, 3],
[0, 1],
[0, 5]]),
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2, 3], dtype=torch.int32),
indices=tensor([4, 1, 0], dtype=torch.int32),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 1, 2], dtype=torch.int32),
indices=tensor([4, 0], dtype=torch.int32),
),
original_row_node_ids=tensor([3, 4, 0, 1, 5, 2]),
original_edge_ids=None,
......@@ -135,12 +151,9 @@ def test_integration_link_prediction():
original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0, 1, 5, 2]),
)],
positive_node_pairs=(tensor([0, 1, 1, 2]),
tensor([0, 0, 1, 1])),
node_pairs_with_labels=((tensor([0, 1, 1, 2, 0, 0, 1, 1, 1, 1, 2, 2]), tensor([0, 0, 1, 1, 3, 4, 5, 4, 1, 0, 3, 4])),
tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.])),
node_pairs=(tensor([3, 4, 4, 0]),
tensor([3, 3, 4, 4])),
positive_node_pairs=None,
node_pairs_with_labels=None,
node_pairs=None,
node_features={'feat': tensor([[0.8672, 0.2276],
[0.5503, 0.8223],
[0.9634, 0.2294],
......@@ -148,37 +161,39 @@ def test_integration_link_prediction():
[0.5160, 0.2486],
[0.2109, 0.1089]])},
negative_srcs=None,
negative_node_pairs=(tensor([[0, 0],
[1, 1],
[1, 1],
[2, 2]]),
tensor([[3, 4],
[5, 4],
[1, 0],
[3, 4]])),
negative_dsts=tensor([[1, 5],
[2, 5],
[4, 3],
[1, 5]]),
labels=None,
negative_node_pairs=None,
negative_dsts=None,
labels=tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.]),
input_nodes=tensor([3, 4, 0, 1, 5, 2]),
indexes=None,
indexes=tensor([0, 1, 2, 3, 0, 0, 1, 1, 2, 2, 3, 3]),
edge_features=[{},
{}],
compacted_seeds=None,
compacted_node_pairs=(tensor([0, 1, 1, 2]),
tensor([0, 0, 1, 1])),
compacted_seeds=tensor([[0, 0],
[1, 0],
[1, 1],
[2, 1],
[0, 3],
[0, 4],
[1, 5],
[1, 4],
[1, 1],
[1, 0],
[2, 3],
[2, 4]]),
compacted_node_pairs=None,
compacted_negative_srcs=None,
compacted_negative_dsts=tensor([[3, 4],
[5, 4],
[1, 0],
[3, 4]]),
blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=3),
compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2),
Block(num_src_nodes=6, num_dst_nodes=6, num_edges=3)],
)"""
),
str(
"""MiniBatch(seeds=None,
"""MiniBatch(seeds=tensor([[5, 5],
[4, 5],
[5, 0],
[5, 4],
[4, 0],
[4, 1]]),
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2], dtype=torch.int32),
indices=tensor([1, 0], dtype=torch.int32),
......@@ -194,34 +209,30 @@ def test_integration_link_prediction():
original_edge_ids=None,
original_column_node_ids=tensor([5, 4, 0, 1]),
)],
positive_node_pairs=(tensor([0, 1]),
tensor([0, 0])),
node_pairs_with_labels=((tensor([0, 1, 0, 0, 1, 1]), tensor([0, 0, 2, 1, 2, 3])),
tensor([1., 1., 0., 0., 0., 0.])),
node_pairs=(tensor([5, 4]),
tensor([5, 5])),
positive_node_pairs=None,
node_pairs_with_labels=None,
node_pairs=None,
node_features={'feat': tensor([[0.5160, 0.2486],
[0.5503, 0.8223],
[0.9634, 0.2294],
[0.6172, 0.7865]])},
negative_srcs=None,
negative_node_pairs=(tensor([[0, 0],
[1, 1]]),
tensor([[2, 1],
[2, 3]])),
negative_dsts=tensor([[0, 4],
[0, 1]]),
labels=None,
negative_node_pairs=None,
negative_dsts=None,
labels=tensor([1., 1., 0., 0., 0., 0.]),
input_nodes=tensor([5, 4, 0, 1]),
indexes=None,
indexes=tensor([0, 1, 0, 0, 1, 1]),
edge_features=[{},
{}],
compacted_seeds=None,
compacted_node_pairs=(tensor([0, 1]),
tensor([0, 0])),
compacted_seeds=tensor([[0, 0],
[1, 0],
[0, 2],
[0, 1],
[1, 2],
[1, 3]]),
compacted_node_pairs=None,
compacted_negative_srcs=None,
compacted_negative_dsts=tensor([[2, 1],
[2, 3]]),
compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=4, num_dst_nodes=4, num_edges=2),
Block(num_src_nodes=4, num_dst_nodes=4, num_edges=2)],
)"""
......@@ -285,43 +296,46 @@ def test_integration_node_classification():
)
expected = [
str(
"""MiniBatch(seeds=None,
"""MiniBatch(seeds=tensor([[5, 1],
[3, 2],
[3, 2],
[3, 3]]),
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4], dtype=torch.int32),
indices=tensor([4, 1, 0, 1], dtype=torch.int32),
indices=tensor([4, 0, 2, 2], dtype=torch.int32),
),
original_row_node_ids=tensor([5, 3, 1, 2, 4]),
original_row_node_ids=tensor([5, 1, 3, 2, 4]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2]),
original_column_node_ids=tensor([5, 1, 3, 2]),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4], dtype=torch.int32),
indices=tensor([0, 1, 0, 1], dtype=torch.int32),
indices=tensor([0, 0, 2, 2], dtype=torch.int32),
),
original_row_node_ids=tensor([5, 3, 1, 2]),
original_row_node_ids=tensor([5, 1, 3, 2]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2]),
original_column_node_ids=tensor([5, 1, 3, 2]),
)],
positive_node_pairs=(tensor([0, 1, 1, 1]),
tensor([2, 3, 3, 1])),
positive_node_pairs=None,
node_pairs_with_labels=None,
node_pairs=(tensor([5, 3, 3, 3]),
tensor([1, 2, 2, 3])),
node_pairs=None,
node_features={'feat': tensor([[0.5160, 0.2486],
[0.8672, 0.2276],
[0.6172, 0.7865],
[0.8672, 0.2276],
[0.2109, 0.1089],
[0.5503, 0.8223]])},
negative_srcs=None,
negative_node_pairs=None,
negative_dsts=None,
labels=None,
input_nodes=tensor([5, 3, 1, 2, 4]),
input_nodes=tensor([5, 1, 3, 2, 4]),
indexes=None,
edge_features=[{},
{}],
compacted_seeds=None,
compacted_node_pairs=(tensor([0, 1, 1, 1]),
tensor([2, 3, 3, 1])),
compacted_seeds=tensor([[0, 1],
[2, 3],
[2, 3],
[2, 2]]),
compacted_node_pairs=None,
compacted_negative_srcs=None,
compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=5, num_dst_nodes=4, num_edges=4),
......@@ -329,7 +343,10 @@ def test_integration_node_classification():
)"""
),
str(
"""MiniBatch(seeds=None,
"""MiniBatch(seeds=tensor([[3, 3],
[4, 3],
[4, 4],
[0, 4]]),
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2], dtype=torch.int32),
indices=tensor([0, 2], dtype=torch.int32),
......@@ -345,11 +362,9 @@ def test_integration_node_classification():
original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0]),
)],
positive_node_pairs=(tensor([0, 1, 1, 2]),
tensor([0, 0, 1, 1])),
positive_node_pairs=None,
node_pairs_with_labels=None,
node_pairs=(tensor([3, 4, 4, 0]),
tensor([3, 3, 4, 4])),
node_pairs=None,
node_features={'feat': tensor([[0.8672, 0.2276],
[0.5503, 0.8223],
[0.9634, 0.2294]])},
......@@ -361,9 +376,11 @@ def test_integration_node_classification():
indexes=None,
edge_features=[{},
{}],
compacted_seeds=None,
compacted_node_pairs=(tensor([0, 1, 1, 2]),
tensor([0, 0, 1, 1])),
compacted_seeds=tensor([[0, 0],
[1, 0],
[1, 1],
[2, 1]]),
compacted_node_pairs=None,
compacted_negative_srcs=None,
compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=3, num_dst_nodes=3, num_edges=2),
......@@ -371,7 +388,8 @@ def test_integration_node_classification():
)"""
),
str(
"""MiniBatch(seeds=None,
"""MiniBatch(seeds=tensor([[5, 5],
[4, 5]]),
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
indices=tensor([0, 2], dtype=torch.int32),
......@@ -387,11 +405,9 @@ def test_integration_node_classification():
original_edge_ids=None,
original_column_node_ids=tensor([5, 4]),
)],
positive_node_pairs=(tensor([0, 1]),
tensor([0, 0])),
positive_node_pairs=None,
node_pairs_with_labels=None,
node_pairs=(tensor([5, 4]),
tensor([5, 5])),
node_pairs=None,
node_features={'feat': tensor([[0.5160, 0.2486],
[0.5503, 0.8223],
[0.9634, 0.2294]])},
......@@ -403,9 +419,9 @@ def test_integration_node_classification():
indexes=None,
edge_features=[{},
{}],
compacted_seeds=None,
compacted_node_pairs=(tensor([0, 1]),
tensor([0, 0])),
compacted_seeds=tensor([[0, 0],
[1, 0]]),
compacted_node_pairs=None,
compacted_negative_srcs=None,
compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=3, num_dst_nodes=2, num_edges=2),
......
......@@ -49,20 +49,20 @@ def test_ItemSampler_minibatcher():
item_sampler = gb.ItemSampler(item_set, batch_size=4)
minibatch = next(iter(item_sampler))
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None
assert len(minibatch.seed_nodes) == 4
assert minibatch.seeds is not None
assert len(minibatch.seeds) == 4
# Customized minibatcher is used if specified.
def minibatcher(batch, names):
return gb.MiniBatch(seed_nodes=batch)
return gb.MiniBatch(seeds=batch)
item_sampler = gb.ItemSampler(
item_set, batch_size=4, minibatcher=minibatcher
)
minibatch = next(iter(item_sampler))
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None
assert len(minibatch.seed_nodes) == 4
assert minibatch.seeds is not None
assert len(minibatch.seeds) == 4
@pytest.mark.parametrize("batch_size", [1, 4])
......@@ -83,17 +83,17 @@ def test_ItemSet_Iterable_Only(batch_size, shuffle, drop_last):
minibatch_ids = []
for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None
assert minibatch.seeds is not None
assert minibatch.labels is None
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
assert len(minibatch.seed_nodes) == batch_size
assert len(minibatch.seeds) == batch_size
else:
if not drop_last:
assert len(minibatch.seed_nodes) == num_ids % batch_size
assert len(minibatch.seeds) == num_ids % batch_size
else:
assert False
minibatch_ids.append(minibatch.seed_nodes)
minibatch_ids.append(minibatch.seeds)
minibatch_ids = torch.cat(minibatch_ids)
assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle
......@@ -111,17 +111,17 @@ def test_ItemSet_integer(batch_size, shuffle, drop_last):
minibatch_ids = []
for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None
assert minibatch.seeds is not None
assert minibatch.labels is None
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
assert len(minibatch.seed_nodes) == batch_size
assert len(minibatch.seeds) == batch_size
else:
if not drop_last:
assert len(minibatch.seed_nodes) == num_ids % batch_size
assert len(minibatch.seeds) == num_ids % batch_size
else:
assert False
minibatch_ids.append(minibatch.seed_nodes)
minibatch_ids.append(minibatch.seeds)
minibatch_ids = torch.cat(minibatch_ids)
assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle
......@@ -140,17 +140,17 @@ def test_ItemSet_seed_nodes(batch_size, shuffle, drop_last):
minibatch_ids = []
for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None
assert minibatch.seeds is not None
assert minibatch.labels is None
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
assert len(minibatch.seed_nodes) == batch_size
assert len(minibatch.seeds) == batch_size
else:
if not drop_last:
assert len(minibatch.seed_nodes) == num_ids % batch_size
assert len(minibatch.seeds) == num_ids % batch_size
else:
assert False
minibatch_ids.append(minibatch.seed_nodes)
minibatch_ids.append(minibatch.seeds)
minibatch_ids = torch.cat(minibatch_ids)
assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle
......@@ -171,18 +171,18 @@ def test_ItemSet_seed_nodes_labels(batch_size, shuffle, drop_last):
minibatch_labels = []
for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None
assert minibatch.seeds is not None
assert minibatch.labels is not None
assert len(minibatch.seed_nodes) == len(minibatch.labels)
assert len(minibatch.seeds) == len(minibatch.labels)
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
assert len(minibatch.seed_nodes) == batch_size
assert len(minibatch.seeds) == batch_size
else:
if not drop_last:
assert len(minibatch.seed_nodes) == num_ids % batch_size
assert len(minibatch.seeds) == num_ids % batch_size
else:
assert False
minibatch_ids.append(minibatch.seed_nodes)
minibatch_ids.append(minibatch.seeds)
minibatch_labels.append(minibatch.labels)
minibatch_ids = torch.cat(minibatch_ids)
minibatch_labels = torch.cat(minibatch_labels)
......@@ -254,10 +254,10 @@ def test_ItemSet_node_pairs(batch_size, shuffle, drop_last):
src_ids = []
dst_ids = []
for i, minibatch in enumerate(item_sampler):
assert minibatch.node_pairs is not None
assert isinstance(minibatch.node_pairs, tuple)
assert minibatch.seeds is not None
assert isinstance(minibatch.seeds, torch.Tensor)
assert minibatch.labels is None
src, dst = minibatch.node_pairs
src, dst = minibatch.seeds.T
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
expected_batch_size = batch_size
......@@ -295,10 +295,10 @@ def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
dst_ids = []
labels = []
for i, minibatch in enumerate(item_sampler):
assert minibatch.node_pairs is not None
assert isinstance(minibatch.node_pairs, tuple)
assert minibatch.seeds is not None
assert isinstance(minibatch.seeds, torch.Tensor)
assert minibatch.labels is not None
src, dst = minibatch.node_pairs
src, dst = minibatch.seeds.T
label = minibatch.labels
assert len(src) == len(dst)
assert len(src) == len(label)
......@@ -349,11 +349,13 @@ def test_ItemSet_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
dst_ids = []
negs_ids = []
for i, minibatch in enumerate(item_sampler):
assert minibatch.node_pairs is not None
assert isinstance(minibatch.node_pairs, tuple)
assert minibatch.negative_dsts is not None
src, dst = minibatch.node_pairs
negs = minibatch.negative_dsts
assert minibatch.seeds is not None
assert isinstance(minibatch.seeds, torch.Tensor)
assert minibatch.labels is not None
assert minibatch.indexes is not None
src, dst = minibatch.seeds.T
negs_src = src[~minibatch.labels.to(bool)]
negs_dst = dst[~minibatch.labels.to(bool)]
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
expected_batch_size = batch_size
......@@ -362,25 +364,32 @@ def test_ItemSet_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
expected_batch_size = num_ids % batch_size
else:
assert False
assert len(src) == expected_batch_size
assert len(dst) == expected_batch_size
assert negs.dim() == 2
assert negs.shape[0] == expected_batch_size
assert negs.shape[1] == num_negs
assert len(src) == expected_batch_size * 3
assert len(dst) == expected_batch_size * 3
assert negs_src.dim() == 1
assert negs_dst.dim() == 1
assert len(negs_src) == expected_batch_size * 2
assert len(negs_dst) == expected_batch_size * 2
expected_indexes = torch.arange(expected_batch_size)
expected_indexes = torch.cat(
(expected_indexes, expected_indexes.repeat_interleave(2))
)
assert torch.equal(minibatch.indexes, expected_indexes)
# Verify node pairs and negative destinations.
assert torch.equal(src + 1, dst)
assert torch.equal(negs[:, 0] + 1, negs[:, 1])
assert torch.equal(
src[minibatch.labels.to(bool)] + 1, dst[minibatch.labels.to(bool)]
)
assert torch.equal((negs_dst - 2 * num_ids) // 2 * 2, negs_src)
# Archive batch.
src_ids.append(src)
dst_ids.append(dst)
negs_ids.append(negs)
negs_ids.append(negs_dst)
src_ids = torch.cat(src_ids)
dst_ids = torch.cat(dst_ids)
negs_ids = torch.cat(negs_ids)
assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle
assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle
assert torch.all(negs_ids[:-1, 0] <= negs_ids[1:, 0]) is not shuffle
assert torch.all(negs_ids[:-1, 1] <= negs_ids[1:, 1]) is not shuffle
assert torch.all(negs_ids[:-1] <= negs_ids[1:]) is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4])
......@@ -472,7 +481,7 @@ def test_append_with_other_datapipes():
data_pipe = data_pipe.enumerate()
for i, (idx, data) in enumerate(data_pipe):
assert i == idx
assert len(data.seed_nodes) == batch_size
assert len(data.seeds) == batch_size
@pytest.mark.parametrize("batch_size", [1, 4])
......@@ -510,9 +519,9 @@ def test_ItemSetDict_iterable_only(batch_size, shuffle, drop_last):
else:
assert False
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None
assert minibatch.seeds is not None
ids = []
for _, v in minibatch.seed_nodes.items():
for _, v in minibatch.seeds.items():
ids.append(v)
ids = torch.cat(ids)
assert len(ids) == expected_batch_size
......@@ -549,9 +558,9 @@ def test_ItemSetDict_seed_nodes(batch_size, shuffle, drop_last):
else:
assert False
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None
assert minibatch.seeds is not None
ids = []
for _, v in minibatch.seed_nodes.items():
for _, v in minibatch.seeds.items():
ids.append(v)
ids = torch.cat(ids)
assert len(ids) == expected_batch_size
......@@ -587,7 +596,7 @@ def test_ItemSetDict_seed_nodes_labels(batch_size, shuffle, drop_last):
minibatch_labels = []
for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None
assert minibatch.seeds is not None
assert minibatch.labels is not None
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
......@@ -598,7 +607,7 @@ def test_ItemSetDict_seed_nodes_labels(batch_size, shuffle, drop_last):
else:
assert False
ids = []
for _, v in minibatch.seed_nodes.items():
for _, v in minibatch.seeds.items():
ids.append(v)
ids = torch.cat(ids)
assert len(ids) == expected_batch_size
......@@ -638,7 +647,7 @@ def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last):
dst_ids = []
for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.node_pairs is not None
assert minibatch.seeds is not None
assert minibatch.labels is None
is_last = (i + 1) * batch_size >= total_pairs
if not is_last or total_pairs % batch_size == 0:
......@@ -650,10 +659,10 @@ def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last):
assert False
src = []
dst = []
for _, (node_pairs) in minibatch.node_pairs.items():
assert isinstance(node_pairs, tuple)
src.append(node_pairs[0])
dst.append(node_pairs[1])
for _, (seeds) in minibatch.seeds.items():
assert isinstance(seeds, torch.Tensor)
src.append(seeds[:, 0])
dst.append(seeds[:, 1])
src = torch.cat(src)
dst = torch.cat(dst)
assert len(src) == expected_batch_size
......@@ -696,8 +705,9 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
labels = []
for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.node_pairs is not None
assert minibatch.seeds is not None
assert minibatch.labels is not None
assert minibatch.negative_dsts is None
is_last = (i + 1) * batch_size >= total_ids
if not is_last or total_ids % batch_size == 0:
expected_batch_size = batch_size
......@@ -709,10 +719,10 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
src = []
dst = []
label = []
for _, node_pairs in minibatch.node_pairs.items():
assert isinstance(node_pairs, tuple)
src.append(node_pairs[0])
dst.append(node_pairs[1])
for _, seeds in minibatch.seeds.items():
assert isinstance(seeds, torch.Tensor)
src.append(seeds[:, 0])
dst.append(seeds[:, 1])
for _, v_label in minibatch.labels.items():
label.append(v_label)
src = torch.cat(src)
......@@ -769,8 +779,9 @@ def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
negs_ids = []
for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.node_pairs is not None
assert minibatch.negative_dsts is not None
assert minibatch.seeds is not None
assert minibatch.labels is not None
assert minibatch.negative_dsts is None
is_last = (i + 1) * batch_size >= total_ids
if not is_last or total_ids % batch_size == 0:
expected_batch_size = batch_size
......@@ -781,33 +792,37 @@ def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
assert False
src = []
dst = []
negs = []
for _, node_pairs in minibatch.node_pairs.items():
assert isinstance(node_pairs, tuple)
src.append(node_pairs[0])
dst.append(node_pairs[1])
for _, v_negs in minibatch.negative_dsts.items():
negs.append(v_negs)
negs_src = []
negs_dst = []
for etype, seeds in minibatch.seeds.items():
assert isinstance(seeds, torch.Tensor)
src_etype = seeds[:, 0]
dst_etype = seeds[:, 1]
src.append(src_etype[minibatch.labels[etype].to(bool)])
dst.append(dst_etype[minibatch.labels[etype].to(bool)])
negs_src.append(src_etype[~minibatch.labels[etype].to(bool)])
negs_dst.append(dst_etype[~minibatch.labels[etype].to(bool)])
src = torch.cat(src)
dst = torch.cat(dst)
negs = torch.cat(negs)
negs_src = torch.cat(negs_src)
negs_dst = torch.cat(negs_dst)
assert len(src) == expected_batch_size
assert len(dst) == expected_batch_size
assert len(negs) == expected_batch_size
assert len(negs_src) == expected_batch_size * 2
assert len(negs_dst) == expected_batch_size * 2
src_ids.append(src)
dst_ids.append(dst)
negs_ids.append(negs)
assert negs.dim() == 2
assert negs.shape[0] == expected_batch_size
assert negs.shape[1] == num_negs
negs_ids.append(negs_dst)
assert negs_src.dim() == 1
assert negs_dst.dim() == 1
assert torch.equal(src + 1, dst)
assert torch.equal(negs[:, 0] + 1, negs[:, 1])
assert torch.equal(negs_src, (negs_dst - num_ids * 4) // 2 * 2)
src_ids = torch.cat(src_ids)
dst_ids = torch.cat(dst_ids)
negs_ids = torch.cat(negs_ids)
assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle
assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle
assert torch.all(negs_ids[:-1] <= negs_ids[1:]) is not shuffle
assert torch.all(negs_ids <= negs_ids) is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4])
......@@ -961,10 +976,10 @@ def distributed_item_sampler_subprocess(
sampled_count = torch.zeros(num_ids, dtype=torch.int32)
for i in data_loader:
# Count how many times each item is sampled.
sampled_count[i.seed_nodes] += 1
sampled_count[i.seeds] += 1
if drop_last:
assert i.seed_nodes.size(0) == batch_size
num_items += i.seed_nodes.size(0)
assert i.seeds.size(0) == batch_size
num_items += i.seeds.size(0)
num_batches = len(list(item_sampler))
if drop_uneven_inputs:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment