Unverified Commit fa17fd09 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] update tests of ItemSet to cover canonical cases (#6272)

parent ac49220c
...@@ -9,16 +9,17 @@ from torch.testing import assert_close ...@@ -9,16 +9,17 @@ from torch.testing import assert_close
def test_ItemSet_names(): def test_ItemSet_names():
# ItemSet with single name. # ItemSet with single name.
item_set = gb.ItemSet(torch.arange(0, 5), names="seed_node") item_set = gb.ItemSet(torch.arange(0, 5), names="seed_nodes")
assert item_set.names == ("seed_node",) assert item_set.names == ("seed_nodes",)
# ItemSet with multiple names. # ItemSet with multiple names.
item_set = gb.ItemSet( item_set = gb.ItemSet(
(torch.arange(0, 5), torch.arange(5, 10)), names=("seed_node", "label") (torch.arange(0, 5), torch.arange(5, 10)),
names=("seed_nodes", "labels"),
) )
assert item_set.names == ("seed_node", "label") assert item_set.names == ("seed_nodes", "labels")
# ItemSet with no name. # ItemSet without name.
item_set = gb.ItemSet(torch.arange(0, 5)) item_set = gb.ItemSet(torch.arange(0, 5))
assert item_set.names is None assert item_set.names is None
...@@ -27,33 +28,120 @@ def test_ItemSet_names(): ...@@ -27,33 +28,120 @@ def test_ItemSet_names():
AssertionError, AssertionError,
match=re.escape("Number of items (1) and names (2) must match."), match=re.escape("Number of items (1) and names (2) must match."),
): ):
_ = gb.ItemSet(torch.arange(0, 5), names=("seed_node", "label")) _ = gb.ItemSet(torch.arange(0, 5), names=("seed_nodes", "labels"))
def test_ItemSet_length():
    """Check ``len()`` on ItemSets built from sized and unsized iterables."""
    # A lone tensor of five IDs.
    single = gb.ItemSet(torch.arange(0, 5))
    assert len(single) == 5

    # A tuple of two tensors, five entries each.
    paired = gb.ItemSet((torch.arange(0, 5), torch.arange(5, 10)))
    assert len(paired) == 5

    class InvalidLength:
        # Iterable only: no ``__len__`` is defined on purpose.
        def __iter__(self):
            return iter([0, 1, 2])

    # ``len()`` is expected to raise TypeError for an unsized iterable,
    # whether it is passed alone or inside a tuple.
    for unsized in (InvalidLength(), (InvalidLength(), InvalidLength())):
        item_set = gb.ItemSet(unsized)
        with pytest.raises(TypeError):
            _ = len(item_set)
def test_ItemSet_iteration_seed_nodes():
    """Iterating an ItemSet of node IDs yields each ID, in order."""
    # Node IDs.
    seed_nodes = torch.arange(0, 5)
    item_set = gb.ItemSet(seed_nodes, names="seed_nodes")
    assert item_set.names == ("seed_nodes",)
    # Materialize first so that an empty or truncated iteration fails the
    # test instead of letting the per-item loop pass vacuously.
    items = list(item_set)
    assert len(items) == len(seed_nodes)
    for i, item in enumerate(items):
        assert i == item.item()
def test_ItemSet_iteration_seed_nodes_labels():
    """Iterating an ItemSet of (node IDs, labels) yields aligned pairs."""
    # Node IDs and labels.
    seed_nodes = torch.arange(0, 5)
    labels = torch.randint(0, 3, (5,))
    item_set = gb.ItemSet((seed_nodes, labels), names=("seed_nodes", "labels"))
    assert item_set.names == ("seed_nodes", "labels")
    # Materialize first so that an empty or truncated iteration fails the
    # test instead of letting the per-item loop pass vacuously.
    items = list(item_set)
    assert len(items) == len(seed_nodes)
    for i, (seed_node, label) in enumerate(items):
        assert seed_node == seed_nodes[i]
        assert label == labels[i]
def test_ItemSet_iteration_node_pairs():
    """Iterating an ItemSet of node pairs yields one (src, dst) row at a time."""
    # Node pairs: a (5, 2) tensor, one pair per row.
    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    item_set = gb.ItemSet(node_pairs, names="node_pairs")
    assert item_set.names == ("node_pairs",)
    # Materialize first so that an empty or truncated iteration fails the
    # test instead of letting the per-item loop pass vacuously.
    items = list(item_set)
    assert len(items) == node_pairs.size(0)
    for i, (src, dst) in enumerate(items):
        assert node_pairs[i][0] == src
        assert node_pairs[i][1] == dst
def test_ItemSet_iteration_node_pairs_labels():
    """Iterating an ItemSet of (node pairs, labels) yields aligned items."""
    # Node pairs and labels.
    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    labels = torch.randint(0, 3, (5,))
    item_set = gb.ItemSet((node_pairs, labels), names=("node_pairs", "labels"))
    assert item_set.names == ("node_pairs", "labels")
    # Materialize first so that an empty or truncated iteration fails the
    # test instead of letting the per-item loop pass vacuously.
    items = list(item_set)
    assert len(items) == node_pairs.size(0)
    for i, (node_pair, label) in enumerate(items):
        assert torch.equal(node_pairs[i], node_pair)
        assert labels[i] == label
def test_ItemSet_iteration_node_pairs_neg_dsts():
    """Iterating (node pairs, negative destinations) yields aligned rows."""
    # Node pairs and negative destinations: (5, 2) and (5, 3) tensors.
    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    neg_dsts = torch.arange(10, 25).reshape(-1, 3)
    item_set = gb.ItemSet(
        (node_pairs, neg_dsts), names=("node_pairs", "neg_dsts")
    )
    assert item_set.names == ("node_pairs", "neg_dsts")
    # Materialize first so that an empty or truncated iteration fails the
    # test instead of letting the per-item loop pass vacuously.
    items = list(item_set)
    assert len(items) == node_pairs.size(0)
    for i, (node_pair, neg_dst) in enumerate(items):
        assert torch.equal(node_pairs[i], node_pair)
        assert torch.equal(neg_dsts[i], neg_dst)
def test_ItemSet_iteration_graphs():
    """Iterating an unnamed ItemSet of graphs yields the graphs in order."""
    # Graphs.
    graphs = [dgl.rand_graph(10, 20) for _ in range(5)]
    item_set = gb.ItemSet(graphs)
    assert item_set.names is None
    # Materialize first so that an empty or truncated iteration fails the
    # test instead of letting the per-item loop pass vacuously.
    items = list(item_set)
    assert len(items) == len(graphs)
    for i, item in enumerate(items):
        assert graphs[i] == item
def test_ItemSetDict_names(): def test_ItemSetDict_names():
# ItemSetDict with single name. # ItemSetDict with single name.
item_set = gb.ItemSetDict( item_set = gb.ItemSetDict(
{ {
"user": gb.ItemSet(torch.arange(0, 5), names="seed_node"), "user": gb.ItemSet(torch.arange(0, 5), names="seed_nodes"),
"item": gb.ItemSet(torch.arange(5, 10), names="seed_node"), "item": gb.ItemSet(torch.arange(5, 10), names="seed_nodes"),
} }
) )
assert item_set.names == ("seed_node",) assert item_set.names == ("seed_nodes",)
# ItemSetDict with multiple names. # ItemSetDict with multiple names.
item_set = gb.ItemSetDict( item_set = gb.ItemSetDict(
{ {
"user": gb.ItemSet( "user": gb.ItemSet(
(torch.arange(0, 5), torch.arange(5, 10)), (torch.arange(0, 5), torch.arange(5, 10)),
names=("seed_node", "label"), names=("seed_nodes", "labels"),
), ),
"item": gb.ItemSet( "item": gb.ItemSet(
(torch.arange(5, 10), torch.arange(10, 15)), (torch.arange(5, 10), torch.arange(10, 15)),
names=("seed_node", "label"), names=("seed_nodes", "labels"),
), ),
} }
) )
assert item_set.names == ("seed_node", "label") assert item_set.names == ("seed_nodes", "labels")
# ItemSetDict with no name. # ItemSetDict with no name.
item_set = gb.ItemSetDict( item_set = gb.ItemSetDict(
...@@ -73,45 +161,17 @@ def test_ItemSetDict_names(): ...@@ -73,45 +161,17 @@ def test_ItemSetDict_names():
{ {
"user": gb.ItemSet( "user": gb.ItemSet(
(torch.arange(0, 5), torch.arange(5, 10)), (torch.arange(0, 5), torch.arange(5, 10)),
names=("seed_node", "label"), names=("seed_nodes", "labels"),
), ),
"item": gb.ItemSet( "item": gb.ItemSet(
(torch.arange(5, 10),), names=("seed_node",) (torch.arange(5, 10),), names=("seed_nodes",)
), ),
} }
) )
def test_ItemSet_valid_length(): def test_ItemSetDict_length():
# Single iterable. # Single iterable with valid length.
ids = torch.arange(0, 5)
item_set = gb.ItemSet(ids)
assert len(item_set) == 5
# Tuple of iterables.
node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
item_set = gb.ItemSet(node_pairs)
assert len(item_set) == 5
def test_ItemSet_invalid_length():
    """``len()`` on an ItemSet backed by unsized iterables raises TypeError."""

    class InvalidLength:
        # Iterable only: no ``__len__`` is defined on purpose.
        def __iter__(self):
            return iter([0, 1, 2])

    # First a single unsized iterable, then a tuple of them.
    for unsized in (InvalidLength(), (InvalidLength(), InvalidLength())):
        item_set = gb.ItemSet(unsized)
        with pytest.raises(TypeError):
            _ = len(item_set)
def test_ItemSetDict_valid_length():
# Single iterable.
user_ids = torch.arange(0, 5) user_ids = torch.arange(0, 5)
item_ids = torch.arange(0, 5) item_ids = torch.arange(0, 5)
item_set = gb.ItemSetDict( item_set = gb.ItemSetDict(
...@@ -122,24 +182,26 @@ def test_ItemSetDict_valid_length(): ...@@ -122,24 +182,26 @@ def test_ItemSetDict_valid_length():
) )
assert len(item_set) == len(user_ids) + len(item_ids) assert len(item_set) == len(user_ids) + len(item_ids)
# Tuple of iterables. # Tuple of iterables with valid length.
like = (torch.arange(0, 5), torch.arange(0, 5)) node_pairs_like = torch.arange(0, 10).reshape(-1, 2)
follow = (torch.arange(0, 5), torch.arange(5, 10)) neg_dsts_like = torch.arange(10, 20).reshape(-1, 2)
node_pairs_follow = torch.arange(0, 10).reshape(-1, 2)
neg_dsts_follow = torch.arange(10, 20).reshape(-1, 2)
item_set = gb.ItemSetDict( item_set = gb.ItemSetDict(
{ {
"user:like:item": gb.ItemSet(like), "user:like:item": gb.ItemSet((node_pairs_like, neg_dsts_like)),
"user:follow:user": gb.ItemSet(follow), "user:follow:user": gb.ItemSet(
(node_pairs_follow, neg_dsts_follow)
),
} }
) )
assert len(item_set) == len(like[0]) + len(follow[0]) assert len(item_set) == node_pairs_like.size(0) + node_pairs_follow.size(0)
def test_ItemSetDict_invalid_length():
class InvalidLength: class InvalidLength:
def __iter__(self): def __iter__(self):
return iter([0, 1, 2]) return iter([0, 1, 2])
# Single iterable. # Single iterable with invalid length.
item_set = gb.ItemSetDict( item_set = gb.ItemSetDict(
{ {
"user": gb.ItemSet(InvalidLength()), "user": gb.ItemSet(InvalidLength()),
...@@ -149,7 +211,7 @@ def test_ItemSetDict_invalid_length(): ...@@ -149,7 +211,7 @@ def test_ItemSetDict_invalid_length():
with pytest.raises(TypeError): with pytest.raises(TypeError):
_ = len(item_set) _ = len(item_set)
# Tuple of iterables. # Tuple of iterables with invalid length.
item_set = gb.ItemSetDict( item_set = gb.ItemSetDict(
{ {
"user:like:item": gb.ItemSet((InvalidLength(), InvalidLength())), "user:like:item": gb.ItemSet((InvalidLength(), InvalidLength())),
...@@ -160,63 +222,19 @@ def test_ItemSetDict_invalid_length(): ...@@ -160,63 +222,19 @@ def test_ItemSetDict_invalid_length():
_ = len(item_set) _ = len(item_set)
def test_ItemSet_node_edge_ids(): def test_ItemSetDict_iteration_seed_nodes():
# Node or edge IDs. # Node IDs.
item_set = gb.ItemSet(torch.arange(0, 5)) user_ids = torch.arange(0, 5)
for i, item in enumerate(item_set): item_ids = torch.arange(5, 10)
assert i == item.item()
def test_ItemSet_graphs():
    """Each item yielded by an ItemSet of graphs equals the source graph."""
    # Five random graphs, each created with ``dgl.rand_graph(10, 20)``.
    graphs = [dgl.rand_graph(10, 20) for _ in range(5)]
    item_set = gb.ItemSet(graphs)
    for idx, yielded in enumerate(item_set):
        assert graphs[idx] == yielded
def test_ItemSet_node_pairs():
    """An ItemSet over a (sources, destinations) tuple yields matching pairs."""
    # Node pairs given as two parallel ID tensors.
    srcs = torch.arange(0, 5)
    dsts = torch.arange(5, 10)
    item_set = gb.ItemSet((srcs, dsts))
    for idx, (src, dst) in enumerate(item_set):
        assert srcs[idx] == src
        assert dsts[idx] == dst
def test_ItemSet_node_pairs_labels():
    """An ItemSet over (sources, destinations, labels) yields aligned triples."""
    # Node pairs as two parallel ID tensors, plus one label per pair.
    srcs = torch.arange(0, 5)
    dsts = torch.arange(5, 10)
    labels = torch.randint(0, 3, (5,))
    item_set = gb.ItemSet((srcs, dsts, labels))
    for idx, (src, dst, label) in enumerate(item_set):
        assert srcs[idx] == src
        assert dsts[idx] == dst
        assert labels[idx] == label
def test_ItemSet_head_tail_neg_tails():
    """An ItemSet over (heads, tails, negative tails) yields aligned triples."""
    # One head, one tail and a row of two negative tails per item.
    heads = torch.arange(0, 5)
    tails = torch.arange(5, 10)
    neg_tails = torch.arange(10, 20).reshape(5, 2)
    item_set = gb.ItemSet((heads, tails, neg_tails))
    for idx, (head, tail, negs) in enumerate(item_set):
        assert heads[idx] == head
        assert tails[idx] == tail
        assert_close(neg_tails[idx], negs)
def test_ItemSetDict_node_edge_ids():
# Node or edge IDs
ids = { ids = {
"user:like:item": gb.ItemSet(torch.arange(0, 5)), "user": gb.ItemSet(user_ids, names="seed_nodes"),
"user:follow:user": gb.ItemSet(torch.arange(0, 5)), "item": gb.ItemSet(item_ids, names="seed_nodes"),
} }
chained_ids = [] chained_ids = []
for key, value in ids.items(): for key, value in ids.items():
chained_ids += [(key, v) for v in value] chained_ids += [(key, v) for v in value]
item_set = gb.ItemSetDict(ids) item_set = gb.ItemSetDict(ids)
assert item_set.names == ("seed_nodes",)
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert len(item) == 1 assert len(item) == 1
assert isinstance(item, dict) assert isinstance(item, dict)
...@@ -224,59 +242,98 @@ def test_ItemSetDict_node_edge_ids(): ...@@ -224,59 +242,98 @@ def test_ItemSetDict_node_edge_ids():
assert item[chained_ids[i][0]] == chained_ids[i][1] assert item[chained_ids[i][0]] == chained_ids[i][1]
def test_ItemSetDict_node_pairs(): def test_ItemSetDict_iteration_seed_nodes_labels():
# Node IDs and labels.
user_ids = torch.arange(0, 5)
user_labels = torch.randint(0, 3, (5,))
item_ids = torch.arange(5, 10)
item_labels = torch.randint(0, 3, (5,))
ids_labels = {
"user": gb.ItemSet(
(user_ids, user_labels), names=("seed_nodes", "labels")
),
"item": gb.ItemSet(
(item_ids, item_labels), names=("seed_nodes", "labels")
),
}
chained_ids = []
for key, value in ids_labels.items():
chained_ids += [(key, v) for v in value]
item_set = gb.ItemSetDict(ids_labels)
assert item_set.names == ("seed_nodes", "labels")
for i, item in enumerate(item_set):
assert len(item) == 1
assert isinstance(item, dict)
assert chained_ids[i][0] in item
assert item[chained_ids[i][0]] == chained_ids[i][1]
def test_ItemSetDict_iteration_node_pairs():
# Node pairs. # Node pairs.
node_pairs = (torch.arange(0, 5), torch.arange(5, 10)) node_pairs = torch.arange(0, 10).reshape(-1, 2)
node_pairs_dict = { node_pairs_dict = {
"user:like:item": gb.ItemSet(node_pairs), "user:like:item": gb.ItemSet(node_pairs, names="node_pairs"),
"user:follow:user": gb.ItemSet(node_pairs), "user:follow:user": gb.ItemSet(node_pairs, names="node_pairs"),
} }
expected_data = [] expected_data = []
for key, value in node_pairs_dict.items(): for key, value in node_pairs_dict.items():
expected_data += [(key, v) for v in value] expected_data += [(key, v) for v in value]
item_set = gb.ItemSetDict(node_pairs_dict) item_set = gb.ItemSetDict(node_pairs_dict)
assert item_set.names == ("node_pairs",)
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert len(item) == 1 assert len(item) == 1
assert isinstance(item, dict) assert isinstance(item, dict)
assert expected_data[i][0] in item assert expected_data[i][0] in item
assert item[expected_data[i][0]] == expected_data[i][1] assert torch.equal(item[expected_data[i][0]], expected_data[i][1])
def test_ItemSetDict_node_pairs_labels(): def test_ItemSetDict_iteration_node_pairs_labels():
# Node pairs and labels # Node pairs and labels
node_pairs = (torch.arange(0, 5), torch.arange(5, 10)) node_pairs = torch.arange(0, 10).reshape(-1, 2)
labels = torch.randint(0, 3, (5,)) labels = torch.randint(0, 3, (5,))
node_pairs_dict = { node_pairs_labels = {
"user:like:item": gb.ItemSet((node_pairs[0], node_pairs[1], labels)), "user:like:item": gb.ItemSet(
"user:follow:user": gb.ItemSet((node_pairs[0], node_pairs[1], labels)), (node_pairs, labels), names=("node_pairs", "labels")
),
"user:follow:user": gb.ItemSet(
(node_pairs, labels), names=("node_pairs", "labels")
),
} }
expected_data = [] expected_data = []
for key, value in node_pairs_dict.items(): for key, value in node_pairs_labels.items():
expected_data += [(key, v) for v in value] expected_data += [(key, v) for v in value]
item_set = gb.ItemSetDict(node_pairs_dict) item_set = gb.ItemSetDict(node_pairs_labels)
assert item_set.names == ("node_pairs", "labels")
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert len(item) == 1 assert len(item) == 1
assert isinstance(item, dict) assert isinstance(item, dict)
assert expected_data[i][0] in item key, value = expected_data[i]
assert item[expected_data[i][0]] == expected_data[i][1] assert key in item
assert torch.equal(item[key][0], value[0])
assert item[key][1] == value[1]
def test_ItemSetDict_head_tail_neg_tails():
# Head, tail and negative tails.
heads = torch.arange(0, 5) def test_ItemSetDict_iteration_node_pairs_neg_dsts():
tails = torch.arange(5, 10) # Node pairs and negative destinations.
neg_tails = torch.arange(10, 20).reshape(5, 2) node_pairs = torch.arange(0, 10).reshape(-1, 2)
item_set = gb.ItemSet((heads, tails, neg_tails)) neg_dsts = torch.arange(10, 25).reshape(-1, 3)
data_dict = { node_pairs_neg_dsts = {
"user:like:item": gb.ItemSet((heads, tails, neg_tails)), "user:like:item": gb.ItemSet(
"user:follow:user": gb.ItemSet((heads, tails, neg_tails)), (node_pairs, neg_dsts), names=("node_pairs", "neg_dsts")
),
"user:follow:user": gb.ItemSet(
(node_pairs, neg_dsts), names=("node_pairs", "neg_dsts")
),
} }
expected_data = [] expected_data = []
for key, value in data_dict.items(): for key, value in node_pairs_neg_dsts.items():
expected_data += [(key, v) for v in value] expected_data += [(key, v) for v in value]
item_set = gb.ItemSetDict(data_dict) item_set = gb.ItemSetDict(node_pairs_neg_dsts)
assert item_set.names == ("node_pairs", "neg_dsts")
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert len(item) == 1 assert len(item) == 1
assert isinstance(item, dict) assert isinstance(item, dict)
assert expected_data[i][0] in item key, value = expected_data[i]
assert_close(item[expected_data[i][0]], expected_data[i][1]) assert key in item
assert torch.equal(item[key][0], value[0])
assert torch.equal(item[key][1], value[1])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment