Unverified Commit 6133cecd authored by Xinyu Yao's avatar Xinyu Yao Committed by GitHub
Browse files

[GraphBolt] Update tests related to `seeds` refactor. (#7352)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent ba9c1521
......@@ -81,7 +81,7 @@ def test_InSubgraphSampler_homo():
graph = gb.fused_csc_sampling_graph(indptr, indices).to(F.ctx())
seed_nodes = torch.LongTensor([0, 5, 3])
item_set = gb.ItemSet(seed_nodes, names="seed_nodes")
item_set = gb.ItemSet(seed_nodes, names="seeds")
batch_size = 1
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(
F.ctx()
......@@ -162,8 +162,8 @@ def test_InSubgraphSampler_hetero():
item_set = gb.ItemSetDict(
{
"N0": gb.ItemSet(torch.LongTensor([1, 0, 2]), names="seed_nodes"),
"N1": gb.ItemSet(torch.LongTensor([0, 2, 1]), names="seed_nodes"),
"N0": gb.ItemSet(torch.LongTensor([1, 0, 2]), names="seeds"),
"N1": gb.ItemSet(torch.LongTensor([0, 2, 1]), names="seeds"),
}
)
batch_size = 2
......
......@@ -47,7 +47,7 @@ def test_NeighborSampler_GraphFetch(hetero, prob_name, sorted):
items = torch.arange(3)
else:
items = torch.tensor([2, 0, 1])
names = "seed_nodes"
names = "seeds"
itemset = gb.ItemSet(items, names=names)
graph = get_hetero_graph().to(F.ctx())
if hetero:
......@@ -94,9 +94,7 @@ def test_labor_dependent_minibatching(layer_dependency, overlap_graph_fetch):
).to(F.ctx())
torch.random.set_rng_state(torch.manual_seed(123).get_state())
batch_dependency = 100
itemset = gb.ItemSet(
torch.zeros(batch_dependency + 1).int(), names="seed_nodes"
)
itemset = gb.ItemSet(torch.zeros(batch_dependency + 1).int(), names="seeds")
datapipe = gb.ItemSampler(itemset, batch_size=1).copy_to(F.ctx())
fanouts = [5, 5]
datapipe = datapipe.sample_layer_neighbor(
......
......@@ -25,7 +25,7 @@ def test_FeatureFetcher_invoke():
features[keys[1]] = gb.TorchBasedFeature(b)
feature_store = gb.BasicFeatureStore(features)
itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
itemset = gb.ItemSet(torch.arange(10), names="seeds")
item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
......@@ -58,7 +58,7 @@ def test_FeatureFetcher_homo():
features[keys[1]] = gb.TorchBasedFeature(b)
feature_store = gb.BasicFeatureStore(features)
itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
itemset = gb.ItemSet(torch.arange(10), names="seeds")
item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
......@@ -104,7 +104,7 @@ def test_FeatureFetcher_with_edges_homo():
features[keys[1]] = gb.TorchBasedFeature(b)
feature_store = gb.BasicFeatureStore(features)
itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
itemset = gb.ItemSet(torch.arange(10), names="seeds")
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
converter_dp = Mapper(item_sampler_dp, add_node_and_edge_ids)
fetcher_dp = gb.FeatureFetcher(converter_dp, feature_store, ["a"], ["b"])
......@@ -152,8 +152,8 @@ def test_FeatureFetcher_hetero():
itemset = gb.ItemSetDict(
{
"n1": gb.ItemSet(torch.LongTensor([0, 1]), names="seed_nodes"),
"n2": gb.ItemSet(torch.LongTensor([0, 1, 2]), names="seed_nodes"),
"n1": gb.ItemSet(torch.LongTensor([0, 1]), names="seeds"),
"n2": gb.ItemSet(torch.LongTensor([0, 1, 2]), names="seeds"),
}
)
item_sampler = gb.ItemSampler(itemset, batch_size=2)
......@@ -215,7 +215,7 @@ def test_FeatureFetcher_with_edges_hetero():
itemset = gb.ItemSetDict(
{
"n1": gb.ItemSet(torch.randint(0, 20, (10,)), names="seed_nodes"),
"n1": gb.ItemSet(torch.randint(0, 20, (10,)), names="seeds"),
}
)
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
......
......@@ -1161,7 +1161,7 @@ def test_DistributedItemSampler(
):
nprocs = 4
batch_size = 4
item_set = gb.ItemSet(torch.arange(0, num_ids), names="seed_nodes")
item_set = gb.ItemSet(torch.arange(0, num_ids), names="seeds")
# On Windows, if the process group initialization file already exists,
# the program may hang. So we need to delete it if it exists.
......
......@@ -8,15 +8,15 @@ from dgl import graphbolt as gb
def test_ItemSet_names():
# ItemSet with single name.
item_set = gb.ItemSet(torch.arange(0, 5), names="seed_nodes")
assert item_set.names == ("seed_nodes",)
item_set = gb.ItemSet(torch.arange(0, 5), names="seeds")
assert item_set.names == ("seeds",)
# ItemSet with multiple names.
item_set = gb.ItemSet(
(torch.arange(0, 5), torch.arange(5, 10)),
names=("seed_nodes", "labels"),
names=("seeds", "labels"),
)
assert item_set.names == ("seed_nodes", "labels")
assert item_set.names == ("seeds", "labels")
# ItemSet without name.
item_set = gb.ItemSet(torch.arange(0, 5))
......@@ -27,19 +27,19 @@ def test_ItemSet_names():
AssertionError,
match=re.escape("Number of items (1) and names (2) must match."),
):
_ = gb.ItemSet(5, names=("seed_nodes", "labels"))
_ = gb.ItemSet(5, names=("seeds", "labels"))
# ItemSet with mismatched items and names.
with pytest.raises(
AssertionError,
match=re.escape("Number of items (1) and names (2) must match."),
):
_ = gb.ItemSet(torch.arange(0, 5), names=("seed_nodes", "labels"))
_ = gb.ItemSet(torch.arange(0, 5), names=("seeds", "labels"))
@pytest.mark.parametrize("dtype", [torch.int32, torch.int64])
def test_ItemSet_scalar_dtype(dtype):
item_set = gb.ItemSet(torch.tensor(5, dtype=dtype), names="seed_nodes")
item_set = gb.ItemSet(torch.tensor(5, dtype=dtype), names="seeds")
for i, item in enumerate(item_set):
assert i == item
assert item.dtype == dtype
......@@ -106,8 +106,8 @@ def test_ItemSet_length():
def test_ItemSet_seed_nodes():
# Node IDs with tensor.
item_set = gb.ItemSet(torch.arange(0, 5), names="seed_nodes")
assert item_set.names == ("seed_nodes",)
item_set = gb.ItemSet(torch.arange(0, 5), names="seeds")
assert item_set.names == ("seeds",)
# Iterating over ItemSet and indexing one by one.
for i, item in enumerate(item_set):
assert i == item.item()
......@@ -118,8 +118,8 @@ def test_ItemSet_seed_nodes():
assert torch.equal(item_set[torch.arange(0, 5)], torch.arange(0, 5))
# Node IDs with single integer.
item_set = gb.ItemSet(5, names="seed_nodes")
assert item_set.names == ("seed_nodes",)
item_set = gb.ItemSet(5, names="seeds")
assert item_set.names == ("seeds",)
# Iterating over ItemSet and indexing one by one.
for i, item in enumerate(item_set):
assert i == item.item()
......@@ -145,8 +145,8 @@ def test_ItemSet_seed_nodes_labels():
# Node IDs and labels.
seed_nodes = torch.arange(0, 5)
labels = torch.randint(0, 3, (5,))
item_set = gb.ItemSet((seed_nodes, labels), names=("seed_nodes", "labels"))
assert item_set.names == ("seed_nodes", "labels")
item_set = gb.ItemSet((seed_nodes, labels), names=("seeds", "labels"))
assert item_set.names == ("seeds", "labels")
# Iterating over ItemSet and indexing one by one.
for i, (seed_node, label) in enumerate(item_set):
assert seed_node == seed_nodes[i]
......@@ -164,8 +164,8 @@ def test_ItemSet_seed_nodes_labels():
def test_ItemSet_node_pairs():
# Node pairs.
node_pairs = torch.arange(0, 10).reshape(-1, 2)
item_set = gb.ItemSet(node_pairs, names="node_pairs")
assert item_set.names == ("node_pairs",)
item_set = gb.ItemSet(node_pairs, names="seeds")
assert item_set.names == ("seeds",)
# Iterating over ItemSet and indexing one by one.
for i, (src, dst) in enumerate(item_set):
assert node_pairs[i][0] == src
......@@ -182,8 +182,8 @@ def test_ItemSet_node_pairs_labels():
# Node pairs and labels
node_pairs = torch.arange(0, 10).reshape(-1, 2)
labels = torch.randint(0, 3, (5,))
item_set = gb.ItemSet((node_pairs, labels), names=("node_pairs", "labels"))
assert item_set.names == ("node_pairs", "labels")
item_set = gb.ItemSet((node_pairs, labels), names=("seeds", "labels"))
assert item_set.names == ("seeds", "labels")
# Iterating over ItemSet and indexing one by one.
for i, (node_pair, label) in enumerate(item_set):
assert torch.equal(node_pairs[i], node_pair)
......@@ -198,26 +198,31 @@ def test_ItemSet_node_pairs_labels():
assert torch.equal(item_set[torch.arange(0, 5)][1], labels)
def test_ItemSet_node_pairs_neg_dsts():
def test_ItemSet_node_pairs_labels_indexes():
# Node pairs and negative destinations.
node_pairs = torch.arange(0, 10).reshape(-1, 2)
neg_dsts = torch.arange(10, 25).reshape(-1, 3)
labels = torch.tensor([1, 1, 0, 0, 0])
indexes = torch.tensor([0, 1, 0, 0, 1])
item_set = gb.ItemSet(
(node_pairs, neg_dsts), names=("node_pairs", "negative_dsts")
(node_pairs, labels, indexes), names=("seeds", "labels", "indexes")
)
assert item_set.names == ("node_pairs", "negative_dsts")
assert item_set.names == ("seeds", "labels", "indexes")
# Iterating over ItemSet and indexing one by one.
for i, (node_pair, neg_dst) in enumerate(item_set):
for i, (node_pair, label, index) in enumerate(item_set):
assert torch.equal(node_pairs[i], node_pair)
assert torch.equal(neg_dsts[i], neg_dst)
assert torch.equal(labels[i], label)
assert torch.equal(indexes[i], index)
assert torch.equal(node_pairs[i], item_set[i][0])
assert torch.equal(neg_dsts[i], item_set[i][1])
assert torch.equal(labels[i], item_set[i][1])
assert torch.equal(indexes[i], item_set[i][2])
# Indexing with a slice.
assert torch.equal(item_set[:][0], node_pairs)
assert torch.equal(item_set[:][1], neg_dsts)
assert torch.equal(item_set[:][1], labels)
assert torch.equal(item_set[:][2], indexes)
# Indexing with an Iterable.
assert torch.equal(item_set[torch.arange(0, 5)][0], node_pairs)
assert torch.equal(item_set[torch.arange(0, 5)][1], neg_dsts)
assert torch.equal(item_set[torch.arange(0, 5)][1], labels)
assert torch.equal(item_set[torch.arange(0, 5)][2], indexes)
def test_ItemSet_graphs():
......@@ -237,26 +242,26 @@ def test_ItemSetDict_names():
# ItemSetDict with single name.
item_set = gb.ItemSetDict(
{
"user": gb.ItemSet(torch.arange(0, 5), names="seed_nodes"),
"item": gb.ItemSet(torch.arange(5, 10), names="seed_nodes"),
"user": gb.ItemSet(torch.arange(0, 5), names="seeds"),
"item": gb.ItemSet(torch.arange(5, 10), names="seeds"),
}
)
assert item_set.names == ("seed_nodes",)
assert item_set.names == ("seeds",)
# ItemSetDict with multiple names.
item_set = gb.ItemSetDict(
{
"user": gb.ItemSet(
(torch.arange(0, 5), torch.arange(5, 10)),
names=("seed_nodes", "labels"),
names=("seeds", "labels"),
),
"item": gb.ItemSet(
(torch.arange(5, 10), torch.arange(10, 15)),
names=("seed_nodes", "labels"),
names=("seeds", "labels"),
),
}
)
assert item_set.names == ("seed_nodes", "labels")
assert item_set.names == ("seeds", "labels")
# ItemSetDict with no name.
item_set = gb.ItemSetDict(
......@@ -276,11 +281,9 @@ def test_ItemSetDict_names():
{
"user": gb.ItemSet(
(torch.arange(0, 5), torch.arange(5, 10)),
names=("seed_nodes", "labels"),
),
"item": gb.ItemSet(
(torch.arange(5, 10),), names=("seed_nodes",)
names=("seeds", "labels"),
),
"item": gb.ItemSet((torch.arange(5, 10),), names=("seeds",)),
}
)
......@@ -354,14 +357,14 @@ def test_ItemSetDict_iteration_seed_nodes():
user_ids = torch.arange(0, 5)
item_ids = torch.arange(5, 10)
ids = {
"user": gb.ItemSet(user_ids, names="seed_nodes"),
"item": gb.ItemSet(item_ids, names="seed_nodes"),
"user": gb.ItemSet(user_ids, names="seeds"),
"item": gb.ItemSet(item_ids, names="seeds"),
}
chained_ids = []
for key, value in ids.items():
chained_ids += [(key, v) for v in value]
item_set = gb.ItemSetDict(ids)
assert item_set.names == ("seed_nodes",)
assert item_set.names == ("seeds",)
# Iterating over ItemSetDict and indexing one by one.
for i, item in enumerate(item_set):
assert len(item) == 1
......@@ -413,18 +416,14 @@ def test_ItemSetDict_iteration_seed_nodes_labels():
item_ids = torch.arange(5, 10)
item_labels = torch.randint(0, 3, (5,))
ids_labels = {
"user": gb.ItemSet(
(user_ids, user_labels), names=("seed_nodes", "labels")
),
"item": gb.ItemSet(
(item_ids, item_labels), names=("seed_nodes", "labels")
),
"user": gb.ItemSet((user_ids, user_labels), names=("seeds", "labels")),
"item": gb.ItemSet((item_ids, item_labels), names=("seeds", "labels")),
}
chained_ids = []
for key, value in ids_labels.items():
chained_ids += [(key, v) for v in value]
item_set = gb.ItemSetDict(ids_labels)
assert item_set.names == ("seed_nodes", "labels")
assert item_set.names == ("seeds", "labels")
# Iterating over ItemSetDict and indexing one by one.
for i, item in enumerate(item_set):
assert len(item) == 1
......@@ -443,14 +442,14 @@ def test_ItemSetDict_iteration_node_pairs():
# Node pairs.
node_pairs = torch.arange(0, 10).reshape(-1, 2)
node_pairs_dict = {
"user:like:item": gb.ItemSet(node_pairs, names="node_pairs"),
"user:follow:user": gb.ItemSet(node_pairs, names="node_pairs"),
"user:like:item": gb.ItemSet(node_pairs, names="seeds"),
"user:follow:user": gb.ItemSet(node_pairs, names="seeds"),
}
expected_data = []
for key, value in node_pairs_dict.items():
expected_data += [(key, v) for v in value]
item_set = gb.ItemSetDict(node_pairs_dict)
assert item_set.names == ("node_pairs",)
assert item_set.names == ("seeds",)
# Iterating over ItemSetDict and indexing one by one.
for i, item in enumerate(item_set):
assert len(item) == 1
......@@ -471,17 +470,17 @@ def test_ItemSetDict_iteration_node_pairs_labels():
labels = torch.randint(0, 3, (5,))
node_pairs_labels = {
"user:like:item": gb.ItemSet(
(node_pairs, labels), names=("node_pairs", "labels")
(node_pairs, labels), names=("seeds", "labels")
),
"user:follow:user": gb.ItemSet(
(node_pairs, labels), names=("node_pairs", "labels")
(node_pairs, labels), names=("seeds", "labels")
),
}
expected_data = []
for key, value in node_pairs_labels.items():
expected_data += [(key, v) for v in value]
item_set = gb.ItemSetDict(node_pairs_labels)
assert item_set.names == ("node_pairs", "labels")
assert item_set.names == ("seeds", "labels")
# Iterating over ItemSetDict and indexing one by one.
for i, item in enumerate(item_set):
assert len(item) == 1
......@@ -501,23 +500,24 @@ def test_ItemSetDict_iteration_node_pairs_labels():
assert torch.equal(item_set[:]["user:follow:user"][1], labels)
def test_ItemSetDict_iteration_node_pairs_neg_dsts():
def test_ItemSetDict_iteration_node_pairs_labels_indexes():
# Node pairs and negative destinations.
node_pairs = torch.arange(0, 10).reshape(-1, 2)
neg_dsts = torch.arange(10, 25).reshape(-1, 3)
labels = torch.tensor([1, 1, 0, 0, 0])
indexes = torch.tensor([0, 1, 0, 0, 1])
node_pairs_neg_dsts = {
"user:like:item": gb.ItemSet(
(node_pairs, neg_dsts), names=("node_pairs", "negative_dsts")
(node_pairs, labels, indexes), names=("seeds", "labels", "indexes")
),
"user:follow:user": gb.ItemSet(
(node_pairs, neg_dsts), names=("node_pairs", "negative_dsts")
(node_pairs, labels, indexes), names=("seeds", "labels", "indexes")
),
}
expected_data = []
for key, value in node_pairs_neg_dsts.items():
expected_data += [(key, v) for v in value]
item_set = gb.ItemSetDict(node_pairs_neg_dsts)
assert item_set.names == ("node_pairs", "negative_dsts")
assert item_set.names == ("seeds", "labels", "indexes")
# Iterating over ItemSetDict and indexing one by one.
for i, item in enumerate(item_set):
assert len(item) == 1
......@@ -526,24 +526,28 @@ def test_ItemSetDict_iteration_node_pairs_neg_dsts():
assert key in item
assert torch.equal(item[key][0], value[0])
assert torch.equal(item[key][1], value[1])
assert torch.equal(item[key][2], value[2])
assert item_set[i].keys() == item.keys()
key = list(item.keys())[0]
assert torch.equal(item_set[i][key][0], item[key][0])
assert torch.equal(item_set[i][key][1], item[key][1])
assert torch.equal(item_set[i][key][2], item[key][2])
# Indexing with a slice.
assert torch.equal(item_set[:]["user:like:item"][0], node_pairs)
assert torch.equal(item_set[:]["user:like:item"][1], neg_dsts)
assert torch.equal(item_set[:]["user:like:item"][1], labels)
assert torch.equal(item_set[:]["user:like:item"][2], indexes)
assert torch.equal(item_set[:]["user:follow:user"][0], node_pairs)
assert torch.equal(item_set[:]["user:follow:user"][1], neg_dsts)
assert torch.equal(item_set[:]["user:follow:user"][1], labels)
assert torch.equal(item_set[:]["user:follow:user"][2], indexes)
def test_ItemSet_repr():
# ItemSet with single name.
item_set = gb.ItemSet(torch.arange(0, 5), names="seed_nodes")
item_set = gb.ItemSet(torch.arange(0, 5), names="seeds")
expected_str = (
"ItemSet(\n"
" items=(tensor([0, 1, 2, 3, 4]),),\n"
" names=('seed_nodes',),\n"
" names=('seeds',),\n"
")"
)
......@@ -552,12 +556,12 @@ def test_ItemSet_repr():
# ItemSet with multiple names.
item_set = gb.ItemSet(
(torch.arange(0, 5), torch.arange(5, 10)),
names=("seed_nodes", "labels"),
names=("seeds", "labels"),
)
expected_str = (
"ItemSet(\n"
" items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n"
" names=('seed_nodes', 'labels'),\n"
" names=('seeds', 'labels'),\n"
")"
)
assert str(item_set) == expected_str, item_set
......@@ -567,20 +571,20 @@ def test_ItemSetDict_repr():
# ItemSetDict with single name.
item_set = gb.ItemSetDict(
{
"user": gb.ItemSet(torch.arange(0, 5), names="seed_nodes"),
"item": gb.ItemSet(torch.arange(5, 10), names="seed_nodes"),
"user": gb.ItemSet(torch.arange(0, 5), names="seeds"),
"item": gb.ItemSet(torch.arange(5, 10), names="seeds"),
}
)
expected_str = (
"ItemSetDict(\n"
" itemsets={'user': ItemSet(\n"
" items=(tensor([0, 1, 2, 3, 4]),),\n"
" names=('seed_nodes',),\n"
" names=('seeds',),\n"
" ), 'item': ItemSet(\n"
" items=(tensor([5, 6, 7, 8, 9]),),\n"
" names=('seed_nodes',),\n"
" names=('seeds',),\n"
" )},\n"
" names=('seed_nodes',),\n"
" names=('seeds',),\n"
")"
)
assert str(item_set) == expected_str, item_set
......@@ -590,11 +594,11 @@ def test_ItemSetDict_repr():
{
"user": gb.ItemSet(
(torch.arange(0, 5), torch.arange(5, 10)),
names=("seed_nodes", "labels"),
names=("seeds", "labels"),
),
"item": gb.ItemSet(
(torch.arange(5, 10), torch.arange(10, 15)),
names=("seed_nodes", "labels"),
names=("seeds", "labels"),
),
}
)
......@@ -602,12 +606,12 @@ def test_ItemSetDict_repr():
"ItemSetDict(\n"
" itemsets={'user': ItemSet(\n"
" items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n"
" names=('seed_nodes', 'labels'),\n"
" names=('seeds', 'labels'),\n"
" ), 'item': ItemSet(\n"
" items=(tensor([5, 6, 7, 8, 9]), tensor([10, 11, 12, 13, 14])),\n"
" names=('seed_nodes', 'labels'),\n"
" names=('seeds', 'labels'),\n"
" )},\n"
" names=('seed_nodes', 'labels'),\n"
" names=('seeds', 'labels'),\n"
")"
)
assert str(item_set) == expected_str, item_set
......@@ -563,11 +563,10 @@ def test_dgl_link_predication_homo():
check_dgl_blocks_homo(minibatch, dgl_blocks)
@pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"])
def test_dgl_link_predication_hetero(mode):
def test_dgl_link_predication_hetero():
# Arrange
minibatch = create_hetero_minibatch()
minibatch.compacted_node_pairs = {
minibatch.compacted_seeds = {
relation: (torch.tensor([[1, 1, 2, 0, 1, 2], [1, 0, 1, 1, 0, 0]]).T,),
reverse_relation: (
torch.tensor([[0, 1, 1, 2, 0, 2], [1, 0, 1, 1, 0, 0]]).T,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment