Unverified commit 6133cecd, authored by Xinyu Yao and committed by GitHub
Browse files

[GraphBolt] Update tests related to `seeds` refactor. (#7352)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent ba9c1521
...@@ -81,7 +81,7 @@ def test_InSubgraphSampler_homo(): ...@@ -81,7 +81,7 @@ def test_InSubgraphSampler_homo():
graph = gb.fused_csc_sampling_graph(indptr, indices).to(F.ctx()) graph = gb.fused_csc_sampling_graph(indptr, indices).to(F.ctx())
seed_nodes = torch.LongTensor([0, 5, 3]) seed_nodes = torch.LongTensor([0, 5, 3])
item_set = gb.ItemSet(seed_nodes, names="seed_nodes") item_set = gb.ItemSet(seed_nodes, names="seeds")
batch_size = 1 batch_size = 1
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to( item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(
F.ctx() F.ctx()
...@@ -162,8 +162,8 @@ def test_InSubgraphSampler_hetero(): ...@@ -162,8 +162,8 @@ def test_InSubgraphSampler_hetero():
item_set = gb.ItemSetDict( item_set = gb.ItemSetDict(
{ {
"N0": gb.ItemSet(torch.LongTensor([1, 0, 2]), names="seed_nodes"), "N0": gb.ItemSet(torch.LongTensor([1, 0, 2]), names="seeds"),
"N1": gb.ItemSet(torch.LongTensor([0, 2, 1]), names="seed_nodes"), "N1": gb.ItemSet(torch.LongTensor([0, 2, 1]), names="seeds"),
} }
) )
batch_size = 2 batch_size = 2
......
...@@ -47,7 +47,7 @@ def test_NeighborSampler_GraphFetch(hetero, prob_name, sorted): ...@@ -47,7 +47,7 @@ def test_NeighborSampler_GraphFetch(hetero, prob_name, sorted):
items = torch.arange(3) items = torch.arange(3)
else: else:
items = torch.tensor([2, 0, 1]) items = torch.tensor([2, 0, 1])
names = "seed_nodes" names = "seeds"
itemset = gb.ItemSet(items, names=names) itemset = gb.ItemSet(items, names=names)
graph = get_hetero_graph().to(F.ctx()) graph = get_hetero_graph().to(F.ctx())
if hetero: if hetero:
...@@ -94,9 +94,7 @@ def test_labor_dependent_minibatching(layer_dependency, overlap_graph_fetch): ...@@ -94,9 +94,7 @@ def test_labor_dependent_minibatching(layer_dependency, overlap_graph_fetch):
).to(F.ctx()) ).to(F.ctx())
torch.random.set_rng_state(torch.manual_seed(123).get_state()) torch.random.set_rng_state(torch.manual_seed(123).get_state())
batch_dependency = 100 batch_dependency = 100
itemset = gb.ItemSet( itemset = gb.ItemSet(torch.zeros(batch_dependency + 1).int(), names="seeds")
torch.zeros(batch_dependency + 1).int(), names="seed_nodes"
)
datapipe = gb.ItemSampler(itemset, batch_size=1).copy_to(F.ctx()) datapipe = gb.ItemSampler(itemset, batch_size=1).copy_to(F.ctx())
fanouts = [5, 5] fanouts = [5, 5]
datapipe = datapipe.sample_layer_neighbor( datapipe = datapipe.sample_layer_neighbor(
......
...@@ -25,7 +25,7 @@ def test_FeatureFetcher_invoke(): ...@@ -25,7 +25,7 @@ def test_FeatureFetcher_invoke():
features[keys[1]] = gb.TorchBasedFeature(b) features[keys[1]] = gb.TorchBasedFeature(b)
feature_store = gb.BasicFeatureStore(features) feature_store = gb.BasicFeatureStore(features)
itemset = gb.ItemSet(torch.arange(10), names="seed_nodes") itemset = gb.ItemSet(torch.arange(10), names="seeds")
item_sampler = gb.ItemSampler(itemset, batch_size=2) item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2 num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
...@@ -58,7 +58,7 @@ def test_FeatureFetcher_homo(): ...@@ -58,7 +58,7 @@ def test_FeatureFetcher_homo():
features[keys[1]] = gb.TorchBasedFeature(b) features[keys[1]] = gb.TorchBasedFeature(b)
feature_store = gb.BasicFeatureStore(features) feature_store = gb.BasicFeatureStore(features)
itemset = gb.ItemSet(torch.arange(10), names="seed_nodes") itemset = gb.ItemSet(torch.arange(10), names="seeds")
item_sampler = gb.ItemSampler(itemset, batch_size=2) item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2 num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
...@@ -104,7 +104,7 @@ def test_FeatureFetcher_with_edges_homo(): ...@@ -104,7 +104,7 @@ def test_FeatureFetcher_with_edges_homo():
features[keys[1]] = gb.TorchBasedFeature(b) features[keys[1]] = gb.TorchBasedFeature(b)
feature_store = gb.BasicFeatureStore(features) feature_store = gb.BasicFeatureStore(features)
itemset = gb.ItemSet(torch.arange(10), names="seed_nodes") itemset = gb.ItemSet(torch.arange(10), names="seeds")
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2) item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
converter_dp = Mapper(item_sampler_dp, add_node_and_edge_ids) converter_dp = Mapper(item_sampler_dp, add_node_and_edge_ids)
fetcher_dp = gb.FeatureFetcher(converter_dp, feature_store, ["a"], ["b"]) fetcher_dp = gb.FeatureFetcher(converter_dp, feature_store, ["a"], ["b"])
...@@ -152,8 +152,8 @@ def test_FeatureFetcher_hetero(): ...@@ -152,8 +152,8 @@ def test_FeatureFetcher_hetero():
itemset = gb.ItemSetDict( itemset = gb.ItemSetDict(
{ {
"n1": gb.ItemSet(torch.LongTensor([0, 1]), names="seed_nodes"), "n1": gb.ItemSet(torch.LongTensor([0, 1]), names="seeds"),
"n2": gb.ItemSet(torch.LongTensor([0, 1, 2]), names="seed_nodes"), "n2": gb.ItemSet(torch.LongTensor([0, 1, 2]), names="seeds"),
} }
) )
item_sampler = gb.ItemSampler(itemset, batch_size=2) item_sampler = gb.ItemSampler(itemset, batch_size=2)
...@@ -215,7 +215,7 @@ def test_FeatureFetcher_with_edges_hetero(): ...@@ -215,7 +215,7 @@ def test_FeatureFetcher_with_edges_hetero():
itemset = gb.ItemSetDict( itemset = gb.ItemSetDict(
{ {
"n1": gb.ItemSet(torch.randint(0, 20, (10,)), names="seed_nodes"), "n1": gb.ItemSet(torch.randint(0, 20, (10,)), names="seeds"),
} }
) )
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2) item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
......
...@@ -1161,7 +1161,7 @@ def test_DistributedItemSampler( ...@@ -1161,7 +1161,7 @@ def test_DistributedItemSampler(
): ):
nprocs = 4 nprocs = 4
batch_size = 4 batch_size = 4
item_set = gb.ItemSet(torch.arange(0, num_ids), names="seed_nodes") item_set = gb.ItemSet(torch.arange(0, num_ids), names="seeds")
# On Windows, if the process group initialization file already exists, # On Windows, if the process group initialization file already exists,
# the program may hang. So we need to delete it if it exists. # the program may hang. So we need to delete it if it exists.
......
...@@ -8,15 +8,15 @@ from dgl import graphbolt as gb ...@@ -8,15 +8,15 @@ from dgl import graphbolt as gb
def test_ItemSet_names(): def test_ItemSet_names():
# ItemSet with single name. # ItemSet with single name.
item_set = gb.ItemSet(torch.arange(0, 5), names="seed_nodes") item_set = gb.ItemSet(torch.arange(0, 5), names="seeds")
assert item_set.names == ("seed_nodes",) assert item_set.names == ("seeds",)
# ItemSet with multiple names. # ItemSet with multiple names.
item_set = gb.ItemSet( item_set = gb.ItemSet(
(torch.arange(0, 5), torch.arange(5, 10)), (torch.arange(0, 5), torch.arange(5, 10)),
names=("seed_nodes", "labels"), names=("seeds", "labels"),
) )
assert item_set.names == ("seed_nodes", "labels") assert item_set.names == ("seeds", "labels")
# ItemSet without name. # ItemSet without name.
item_set = gb.ItemSet(torch.arange(0, 5)) item_set = gb.ItemSet(torch.arange(0, 5))
...@@ -27,19 +27,19 @@ def test_ItemSet_names(): ...@@ -27,19 +27,19 @@ def test_ItemSet_names():
AssertionError, AssertionError,
match=re.escape("Number of items (1) and names (2) must match."), match=re.escape("Number of items (1) and names (2) must match."),
): ):
_ = gb.ItemSet(5, names=("seed_nodes", "labels")) _ = gb.ItemSet(5, names=("seeds", "labels"))
# ItemSet with mismatched items and names. # ItemSet with mismatched items and names.
with pytest.raises( with pytest.raises(
AssertionError, AssertionError,
match=re.escape("Number of items (1) and names (2) must match."), match=re.escape("Number of items (1) and names (2) must match."),
): ):
_ = gb.ItemSet(torch.arange(0, 5), names=("seed_nodes", "labels")) _ = gb.ItemSet(torch.arange(0, 5), names=("seeds", "labels"))
@pytest.mark.parametrize("dtype", [torch.int32, torch.int64]) @pytest.mark.parametrize("dtype", [torch.int32, torch.int64])
def test_ItemSet_scalar_dtype(dtype): def test_ItemSet_scalar_dtype(dtype):
item_set = gb.ItemSet(torch.tensor(5, dtype=dtype), names="seed_nodes") item_set = gb.ItemSet(torch.tensor(5, dtype=dtype), names="seeds")
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert i == item assert i == item
assert item.dtype == dtype assert item.dtype == dtype
...@@ -106,8 +106,8 @@ def test_ItemSet_length(): ...@@ -106,8 +106,8 @@ def test_ItemSet_length():
def test_ItemSet_seed_nodes(): def test_ItemSet_seed_nodes():
# Node IDs with tensor. # Node IDs with tensor.
item_set = gb.ItemSet(torch.arange(0, 5), names="seed_nodes") item_set = gb.ItemSet(torch.arange(0, 5), names="seeds")
assert item_set.names == ("seed_nodes",) assert item_set.names == ("seeds",)
# Iterating over ItemSet and indexing one by one. # Iterating over ItemSet and indexing one by one.
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert i == item.item() assert i == item.item()
...@@ -118,8 +118,8 @@ def test_ItemSet_seed_nodes(): ...@@ -118,8 +118,8 @@ def test_ItemSet_seed_nodes():
assert torch.equal(item_set[torch.arange(0, 5)], torch.arange(0, 5)) assert torch.equal(item_set[torch.arange(0, 5)], torch.arange(0, 5))
# Node IDs with single integer. # Node IDs with single integer.
item_set = gb.ItemSet(5, names="seed_nodes") item_set = gb.ItemSet(5, names="seeds")
assert item_set.names == ("seed_nodes",) assert item_set.names == ("seeds",)
# Iterating over ItemSet and indexing one by one. # Iterating over ItemSet and indexing one by one.
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert i == item.item() assert i == item.item()
...@@ -145,8 +145,8 @@ def test_ItemSet_seed_nodes_labels(): ...@@ -145,8 +145,8 @@ def test_ItemSet_seed_nodes_labels():
# Node IDs and labels. # Node IDs and labels.
seed_nodes = torch.arange(0, 5) seed_nodes = torch.arange(0, 5)
labels = torch.randint(0, 3, (5,)) labels = torch.randint(0, 3, (5,))
item_set = gb.ItemSet((seed_nodes, labels), names=("seed_nodes", "labels")) item_set = gb.ItemSet((seed_nodes, labels), names=("seeds", "labels"))
assert item_set.names == ("seed_nodes", "labels") assert item_set.names == ("seeds", "labels")
# Iterating over ItemSet and indexing one by one. # Iterating over ItemSet and indexing one by one.
for i, (seed_node, label) in enumerate(item_set): for i, (seed_node, label) in enumerate(item_set):
assert seed_node == seed_nodes[i] assert seed_node == seed_nodes[i]
...@@ -164,8 +164,8 @@ def test_ItemSet_seed_nodes_labels(): ...@@ -164,8 +164,8 @@ def test_ItemSet_seed_nodes_labels():
def test_ItemSet_node_pairs(): def test_ItemSet_node_pairs():
# Node pairs. # Node pairs.
node_pairs = torch.arange(0, 10).reshape(-1, 2) node_pairs = torch.arange(0, 10).reshape(-1, 2)
item_set = gb.ItemSet(node_pairs, names="node_pairs") item_set = gb.ItemSet(node_pairs, names="seeds")
assert item_set.names == ("node_pairs",) assert item_set.names == ("seeds",)
# Iterating over ItemSet and indexing one by one. # Iterating over ItemSet and indexing one by one.
for i, (src, dst) in enumerate(item_set): for i, (src, dst) in enumerate(item_set):
assert node_pairs[i][0] == src assert node_pairs[i][0] == src
...@@ -182,8 +182,8 @@ def test_ItemSet_node_pairs_labels(): ...@@ -182,8 +182,8 @@ def test_ItemSet_node_pairs_labels():
# Node pairs and labels # Node pairs and labels
node_pairs = torch.arange(0, 10).reshape(-1, 2) node_pairs = torch.arange(0, 10).reshape(-1, 2)
labels = torch.randint(0, 3, (5,)) labels = torch.randint(0, 3, (5,))
item_set = gb.ItemSet((node_pairs, labels), names=("node_pairs", "labels")) item_set = gb.ItemSet((node_pairs, labels), names=("seeds", "labels"))
assert item_set.names == ("node_pairs", "labels") assert item_set.names == ("seeds", "labels")
# Iterating over ItemSet and indexing one by one. # Iterating over ItemSet and indexing one by one.
for i, (node_pair, label) in enumerate(item_set): for i, (node_pair, label) in enumerate(item_set):
assert torch.equal(node_pairs[i], node_pair) assert torch.equal(node_pairs[i], node_pair)
...@@ -198,26 +198,31 @@ def test_ItemSet_node_pairs_labels(): ...@@ -198,26 +198,31 @@ def test_ItemSet_node_pairs_labels():
assert torch.equal(item_set[torch.arange(0, 5)][1], labels) assert torch.equal(item_set[torch.arange(0, 5)][1], labels)
def test_ItemSet_node_pairs_neg_dsts(): def test_ItemSet_node_pairs_labels_indexes():
# Node pairs and negative destinations. # Node pairs and negative destinations.
node_pairs = torch.arange(0, 10).reshape(-1, 2) node_pairs = torch.arange(0, 10).reshape(-1, 2)
neg_dsts = torch.arange(10, 25).reshape(-1, 3) labels = torch.tensor([1, 1, 0, 0, 0])
indexes = torch.tensor([0, 1, 0, 0, 1])
item_set = gb.ItemSet( item_set = gb.ItemSet(
(node_pairs, neg_dsts), names=("node_pairs", "negative_dsts") (node_pairs, labels, indexes), names=("seeds", "labels", "indexes")
) )
assert item_set.names == ("node_pairs", "negative_dsts") assert item_set.names == ("seeds", "labels", "indexes")
# Iterating over ItemSet and indexing one by one. # Iterating over ItemSet and indexing one by one.
for i, (node_pair, neg_dst) in enumerate(item_set): for i, (node_pair, label, index) in enumerate(item_set):
assert torch.equal(node_pairs[i], node_pair) assert torch.equal(node_pairs[i], node_pair)
assert torch.equal(neg_dsts[i], neg_dst) assert torch.equal(labels[i], label)
assert torch.equal(indexes[i], index)
assert torch.equal(node_pairs[i], item_set[i][0]) assert torch.equal(node_pairs[i], item_set[i][0])
assert torch.equal(neg_dsts[i], item_set[i][1]) assert torch.equal(labels[i], item_set[i][1])
assert torch.equal(indexes[i], item_set[i][2])
# Indexing with a slice. # Indexing with a slice.
assert torch.equal(item_set[:][0], node_pairs) assert torch.equal(item_set[:][0], node_pairs)
assert torch.equal(item_set[:][1], neg_dsts) assert torch.equal(item_set[:][1], labels)
assert torch.equal(item_set[:][2], indexes)
# Indexing with an Iterable. # Indexing with an Iterable.
assert torch.equal(item_set[torch.arange(0, 5)][0], node_pairs) assert torch.equal(item_set[torch.arange(0, 5)][0], node_pairs)
assert torch.equal(item_set[torch.arange(0, 5)][1], neg_dsts) assert torch.equal(item_set[torch.arange(0, 5)][1], labels)
assert torch.equal(item_set[torch.arange(0, 5)][2], indexes)
def test_ItemSet_graphs(): def test_ItemSet_graphs():
...@@ -237,26 +242,26 @@ def test_ItemSetDict_names(): ...@@ -237,26 +242,26 @@ def test_ItemSetDict_names():
# ItemSetDict with single name. # ItemSetDict with single name.
item_set = gb.ItemSetDict( item_set = gb.ItemSetDict(
{ {
"user": gb.ItemSet(torch.arange(0, 5), names="seed_nodes"), "user": gb.ItemSet(torch.arange(0, 5), names="seeds"),
"item": gb.ItemSet(torch.arange(5, 10), names="seed_nodes"), "item": gb.ItemSet(torch.arange(5, 10), names="seeds"),
} }
) )
assert item_set.names == ("seed_nodes",) assert item_set.names == ("seeds",)
# ItemSetDict with multiple names. # ItemSetDict with multiple names.
item_set = gb.ItemSetDict( item_set = gb.ItemSetDict(
{ {
"user": gb.ItemSet( "user": gb.ItemSet(
(torch.arange(0, 5), torch.arange(5, 10)), (torch.arange(0, 5), torch.arange(5, 10)),
names=("seed_nodes", "labels"), names=("seeds", "labels"),
), ),
"item": gb.ItemSet( "item": gb.ItemSet(
(torch.arange(5, 10), torch.arange(10, 15)), (torch.arange(5, 10), torch.arange(10, 15)),
names=("seed_nodes", "labels"), names=("seeds", "labels"),
), ),
} }
) )
assert item_set.names == ("seed_nodes", "labels") assert item_set.names == ("seeds", "labels")
# ItemSetDict with no name. # ItemSetDict with no name.
item_set = gb.ItemSetDict( item_set = gb.ItemSetDict(
...@@ -276,11 +281,9 @@ def test_ItemSetDict_names(): ...@@ -276,11 +281,9 @@ def test_ItemSetDict_names():
{ {
"user": gb.ItemSet( "user": gb.ItemSet(
(torch.arange(0, 5), torch.arange(5, 10)), (torch.arange(0, 5), torch.arange(5, 10)),
names=("seed_nodes", "labels"), names=("seeds", "labels"),
),
"item": gb.ItemSet(
(torch.arange(5, 10),), names=("seed_nodes",)
), ),
"item": gb.ItemSet((torch.arange(5, 10),), names=("seeds",)),
} }
) )
...@@ -354,14 +357,14 @@ def test_ItemSetDict_iteration_seed_nodes(): ...@@ -354,14 +357,14 @@ def test_ItemSetDict_iteration_seed_nodes():
user_ids = torch.arange(0, 5) user_ids = torch.arange(0, 5)
item_ids = torch.arange(5, 10) item_ids = torch.arange(5, 10)
ids = { ids = {
"user": gb.ItemSet(user_ids, names="seed_nodes"), "user": gb.ItemSet(user_ids, names="seeds"),
"item": gb.ItemSet(item_ids, names="seed_nodes"), "item": gb.ItemSet(item_ids, names="seeds"),
} }
chained_ids = [] chained_ids = []
for key, value in ids.items(): for key, value in ids.items():
chained_ids += [(key, v) for v in value] chained_ids += [(key, v) for v in value]
item_set = gb.ItemSetDict(ids) item_set = gb.ItemSetDict(ids)
assert item_set.names == ("seed_nodes",) assert item_set.names == ("seeds",)
# Iterating over ItemSetDict and indexing one by one. # Iterating over ItemSetDict and indexing one by one.
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert len(item) == 1 assert len(item) == 1
...@@ -413,18 +416,14 @@ def test_ItemSetDict_iteration_seed_nodes_labels(): ...@@ -413,18 +416,14 @@ def test_ItemSetDict_iteration_seed_nodes_labels():
item_ids = torch.arange(5, 10) item_ids = torch.arange(5, 10)
item_labels = torch.randint(0, 3, (5,)) item_labels = torch.randint(0, 3, (5,))
ids_labels = { ids_labels = {
"user": gb.ItemSet( "user": gb.ItemSet((user_ids, user_labels), names=("seeds", "labels")),
(user_ids, user_labels), names=("seed_nodes", "labels") "item": gb.ItemSet((item_ids, item_labels), names=("seeds", "labels")),
),
"item": gb.ItemSet(
(item_ids, item_labels), names=("seed_nodes", "labels")
),
} }
chained_ids = [] chained_ids = []
for key, value in ids_labels.items(): for key, value in ids_labels.items():
chained_ids += [(key, v) for v in value] chained_ids += [(key, v) for v in value]
item_set = gb.ItemSetDict(ids_labels) item_set = gb.ItemSetDict(ids_labels)
assert item_set.names == ("seed_nodes", "labels") assert item_set.names == ("seeds", "labels")
# Iterating over ItemSetDict and indexing one by one. # Iterating over ItemSetDict and indexing one by one.
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert len(item) == 1 assert len(item) == 1
...@@ -443,14 +442,14 @@ def test_ItemSetDict_iteration_node_pairs(): ...@@ -443,14 +442,14 @@ def test_ItemSetDict_iteration_node_pairs():
# Node pairs. # Node pairs.
node_pairs = torch.arange(0, 10).reshape(-1, 2) node_pairs = torch.arange(0, 10).reshape(-1, 2)
node_pairs_dict = { node_pairs_dict = {
"user:like:item": gb.ItemSet(node_pairs, names="node_pairs"), "user:like:item": gb.ItemSet(node_pairs, names="seeds"),
"user:follow:user": gb.ItemSet(node_pairs, names="node_pairs"), "user:follow:user": gb.ItemSet(node_pairs, names="seeds"),
} }
expected_data = [] expected_data = []
for key, value in node_pairs_dict.items(): for key, value in node_pairs_dict.items():
expected_data += [(key, v) for v in value] expected_data += [(key, v) for v in value]
item_set = gb.ItemSetDict(node_pairs_dict) item_set = gb.ItemSetDict(node_pairs_dict)
assert item_set.names == ("node_pairs",) assert item_set.names == ("seeds",)
# Iterating over ItemSetDict and indexing one by one. # Iterating over ItemSetDict and indexing one by one.
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert len(item) == 1 assert len(item) == 1
...@@ -471,17 +470,17 @@ def test_ItemSetDict_iteration_node_pairs_labels(): ...@@ -471,17 +470,17 @@ def test_ItemSetDict_iteration_node_pairs_labels():
labels = torch.randint(0, 3, (5,)) labels = torch.randint(0, 3, (5,))
node_pairs_labels = { node_pairs_labels = {
"user:like:item": gb.ItemSet( "user:like:item": gb.ItemSet(
(node_pairs, labels), names=("node_pairs", "labels") (node_pairs, labels), names=("seeds", "labels")
), ),
"user:follow:user": gb.ItemSet( "user:follow:user": gb.ItemSet(
(node_pairs, labels), names=("node_pairs", "labels") (node_pairs, labels), names=("seeds", "labels")
), ),
} }
expected_data = [] expected_data = []
for key, value in node_pairs_labels.items(): for key, value in node_pairs_labels.items():
expected_data += [(key, v) for v in value] expected_data += [(key, v) for v in value]
item_set = gb.ItemSetDict(node_pairs_labels) item_set = gb.ItemSetDict(node_pairs_labels)
assert item_set.names == ("node_pairs", "labels") assert item_set.names == ("seeds", "labels")
# Iterating over ItemSetDict and indexing one by one. # Iterating over ItemSetDict and indexing one by one.
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert len(item) == 1 assert len(item) == 1
...@@ -501,23 +500,24 @@ def test_ItemSetDict_iteration_node_pairs_labels(): ...@@ -501,23 +500,24 @@ def test_ItemSetDict_iteration_node_pairs_labels():
assert torch.equal(item_set[:]["user:follow:user"][1], labels) assert torch.equal(item_set[:]["user:follow:user"][1], labels)
def test_ItemSetDict_iteration_node_pairs_neg_dsts(): def test_ItemSetDict_iteration_node_pairs_labels_indexes():
# Node pairs and negative destinations. # Node pairs and negative destinations.
node_pairs = torch.arange(0, 10).reshape(-1, 2) node_pairs = torch.arange(0, 10).reshape(-1, 2)
neg_dsts = torch.arange(10, 25).reshape(-1, 3) labels = torch.tensor([1, 1, 0, 0, 0])
indexes = torch.tensor([0, 1, 0, 0, 1])
node_pairs_neg_dsts = { node_pairs_neg_dsts = {
"user:like:item": gb.ItemSet( "user:like:item": gb.ItemSet(
(node_pairs, neg_dsts), names=("node_pairs", "negative_dsts") (node_pairs, labels, indexes), names=("seeds", "labels", "indexes")
), ),
"user:follow:user": gb.ItemSet( "user:follow:user": gb.ItemSet(
(node_pairs, neg_dsts), names=("node_pairs", "negative_dsts") (node_pairs, labels, indexes), names=("seeds", "labels", "indexes")
), ),
} }
expected_data = [] expected_data = []
for key, value in node_pairs_neg_dsts.items(): for key, value in node_pairs_neg_dsts.items():
expected_data += [(key, v) for v in value] expected_data += [(key, v) for v in value]
item_set = gb.ItemSetDict(node_pairs_neg_dsts) item_set = gb.ItemSetDict(node_pairs_neg_dsts)
assert item_set.names == ("node_pairs", "negative_dsts") assert item_set.names == ("seeds", "labels", "indexes")
# Iterating over ItemSetDict and indexing one by one. # Iterating over ItemSetDict and indexing one by one.
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert len(item) == 1 assert len(item) == 1
...@@ -526,24 +526,28 @@ def test_ItemSetDict_iteration_node_pairs_neg_dsts(): ...@@ -526,24 +526,28 @@ def test_ItemSetDict_iteration_node_pairs_neg_dsts():
assert key in item assert key in item
assert torch.equal(item[key][0], value[0]) assert torch.equal(item[key][0], value[0])
assert torch.equal(item[key][1], value[1]) assert torch.equal(item[key][1], value[1])
assert torch.equal(item[key][2], value[2])
assert item_set[i].keys() == item.keys() assert item_set[i].keys() == item.keys()
key = list(item.keys())[0] key = list(item.keys())[0]
assert torch.equal(item_set[i][key][0], item[key][0]) assert torch.equal(item_set[i][key][0], item[key][0])
assert torch.equal(item_set[i][key][1], item[key][1]) assert torch.equal(item_set[i][key][1], item[key][1])
assert torch.equal(item_set[i][key][2], item[key][2])
# Indexing with a slice. # Indexing with a slice.
assert torch.equal(item_set[:]["user:like:item"][0], node_pairs) assert torch.equal(item_set[:]["user:like:item"][0], node_pairs)
assert torch.equal(item_set[:]["user:like:item"][1], neg_dsts) assert torch.equal(item_set[:]["user:like:item"][1], labels)
assert torch.equal(item_set[:]["user:like:item"][2], indexes)
assert torch.equal(item_set[:]["user:follow:user"][0], node_pairs) assert torch.equal(item_set[:]["user:follow:user"][0], node_pairs)
assert torch.equal(item_set[:]["user:follow:user"][1], neg_dsts) assert torch.equal(item_set[:]["user:follow:user"][1], labels)
assert torch.equal(item_set[:]["user:follow:user"][2], indexes)
def test_ItemSet_repr(): def test_ItemSet_repr():
# ItemSet with single name. # ItemSet with single name.
item_set = gb.ItemSet(torch.arange(0, 5), names="seed_nodes") item_set = gb.ItemSet(torch.arange(0, 5), names="seeds")
expected_str = ( expected_str = (
"ItemSet(\n" "ItemSet(\n"
" items=(tensor([0, 1, 2, 3, 4]),),\n" " items=(tensor([0, 1, 2, 3, 4]),),\n"
" names=('seed_nodes',),\n" " names=('seeds',),\n"
")" ")"
) )
...@@ -552,12 +556,12 @@ def test_ItemSet_repr(): ...@@ -552,12 +556,12 @@ def test_ItemSet_repr():
# ItemSet with multiple names. # ItemSet with multiple names.
item_set = gb.ItemSet( item_set = gb.ItemSet(
(torch.arange(0, 5), torch.arange(5, 10)), (torch.arange(0, 5), torch.arange(5, 10)),
names=("seed_nodes", "labels"), names=("seeds", "labels"),
) )
expected_str = ( expected_str = (
"ItemSet(\n" "ItemSet(\n"
" items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n" " items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n"
" names=('seed_nodes', 'labels'),\n" " names=('seeds', 'labels'),\n"
")" ")"
) )
assert str(item_set) == expected_str, item_set assert str(item_set) == expected_str, item_set
...@@ -567,20 +571,20 @@ def test_ItemSetDict_repr(): ...@@ -567,20 +571,20 @@ def test_ItemSetDict_repr():
# ItemSetDict with single name. # ItemSetDict with single name.
item_set = gb.ItemSetDict( item_set = gb.ItemSetDict(
{ {
"user": gb.ItemSet(torch.arange(0, 5), names="seed_nodes"), "user": gb.ItemSet(torch.arange(0, 5), names="seeds"),
"item": gb.ItemSet(torch.arange(5, 10), names="seed_nodes"), "item": gb.ItemSet(torch.arange(5, 10), names="seeds"),
} }
) )
expected_str = ( expected_str = (
"ItemSetDict(\n" "ItemSetDict(\n"
" itemsets={'user': ItemSet(\n" " itemsets={'user': ItemSet(\n"
" items=(tensor([0, 1, 2, 3, 4]),),\n" " items=(tensor([0, 1, 2, 3, 4]),),\n"
" names=('seed_nodes',),\n" " names=('seeds',),\n"
" ), 'item': ItemSet(\n" " ), 'item': ItemSet(\n"
" items=(tensor([5, 6, 7, 8, 9]),),\n" " items=(tensor([5, 6, 7, 8, 9]),),\n"
" names=('seed_nodes',),\n" " names=('seeds',),\n"
" )},\n" " )},\n"
" names=('seed_nodes',),\n" " names=('seeds',),\n"
")" ")"
) )
assert str(item_set) == expected_str, item_set assert str(item_set) == expected_str, item_set
...@@ -590,11 +594,11 @@ def test_ItemSetDict_repr(): ...@@ -590,11 +594,11 @@ def test_ItemSetDict_repr():
{ {
"user": gb.ItemSet( "user": gb.ItemSet(
(torch.arange(0, 5), torch.arange(5, 10)), (torch.arange(0, 5), torch.arange(5, 10)),
names=("seed_nodes", "labels"), names=("seeds", "labels"),
), ),
"item": gb.ItemSet( "item": gb.ItemSet(
(torch.arange(5, 10), torch.arange(10, 15)), (torch.arange(5, 10), torch.arange(10, 15)),
names=("seed_nodes", "labels"), names=("seeds", "labels"),
), ),
} }
) )
...@@ -602,12 +606,12 @@ def test_ItemSetDict_repr(): ...@@ -602,12 +606,12 @@ def test_ItemSetDict_repr():
"ItemSetDict(\n" "ItemSetDict(\n"
" itemsets={'user': ItemSet(\n" " itemsets={'user': ItemSet(\n"
" items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n" " items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n"
" names=('seed_nodes', 'labels'),\n" " names=('seeds', 'labels'),\n"
" ), 'item': ItemSet(\n" " ), 'item': ItemSet(\n"
" items=(tensor([5, 6, 7, 8, 9]), tensor([10, 11, 12, 13, 14])),\n" " items=(tensor([5, 6, 7, 8, 9]), tensor([10, 11, 12, 13, 14])),\n"
" names=('seed_nodes', 'labels'),\n" " names=('seeds', 'labels'),\n"
" )},\n" " )},\n"
" names=('seed_nodes', 'labels'),\n" " names=('seeds', 'labels'),\n"
")" ")"
) )
assert str(item_set) == expected_str, item_set assert str(item_set) == expected_str, item_set
...@@ -563,11 +563,10 @@ def test_dgl_link_predication_homo(): ...@@ -563,11 +563,10 @@ def test_dgl_link_predication_homo():
check_dgl_blocks_homo(minibatch, dgl_blocks) check_dgl_blocks_homo(minibatch, dgl_blocks)
@pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"]) def test_dgl_link_predication_hetero():
def test_dgl_link_predication_hetero(mode):
# Arrange # Arrange
minibatch = create_hetero_minibatch() minibatch = create_hetero_minibatch()
minibatch.compacted_node_pairs = { minibatch.compacted_seeds = {
relation: (torch.tensor([[1, 1, 2, 0, 1, 2], [1, 0, 1, 1, 0, 0]]).T,), relation: (torch.tensor([[1, 1, 2, 0, 1, 2], [1, 0, 1, 1, 0, 0]]).T,),
reverse_relation: ( reverse_relation: (
torch.tensor([[0, 1, 1, 2, 0, 2], [1, 0, 1, 1, 0, 0]]).T, torch.tensor([[0, 1, 1, 2, 0, 2], [1, 0, 1, 1, 0, 0]]).T,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment