"vscode:/vscode.git/clone" did not exist on "2fedcdc2feda37b56fcf3198a4eac79ba8ae9a2d"
Unverified commit 2144a3ce, authored by Rhett Ying and committed by GitHub
Browse files

[GraphBolt] use str instead of tuple for canonical etype (#6216)


Co-authored-by: Hongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent 4c5489e8
......@@ -100,14 +100,14 @@ class ItemSetDict:
>>> node_pairs_like = (torch.arange(0, 2), torch.arange(0, 2))
>>> node_pairs_follow = (torch.arange(0, 3), torch.arange(3, 6))
>>> item_set = gb.ItemSetDict({
... ('user', 'like', 'item'): gb.ItemSet(node_pairs_like),
... ('user', 'follow', 'user'): gb.ItemSet(node_pairs_follow)})
... "user:like:item": gb.ItemSet(node_pairs_like),
... "user:follow:user": gb.ItemSet(node_pairs_follow)})
>>> list(item_set)
[{('user', 'like', 'item'): (tensor(0), tensor(0))},
{('user', 'like', 'item'): (tensor(1), tensor(1))},
{('user', 'follow', 'user'): (tensor(0), tensor(3))},
{('user', 'follow', 'user'): (tensor(1), tensor(4))},
{('user', 'follow', 'user'): (tensor(2), tensor(5))}]
[{"user:like:item": (tensor(0), tensor(0))},
{"user:like:item": (tensor(1), tensor(1))},
{"user:follow:user": (tensor(0), tensor(3))},
{"user:follow:user": (tensor(1), tensor(4))},
{"user:follow:user": (tensor(2), tensor(5))}]
3. Tuple of iterables with different shape.
>>> like = (torch.arange(0, 2), torch.arange(0, 2),
......@@ -115,14 +115,14 @@ class ItemSetDict:
>>> follow = (torch.arange(0, 3), torch.arange(3, 6),
... torch.arange(0, 6).reshape(-1, 2))
>>> item_set = gb.ItemSetDict({
... ('user', 'like', 'item'): gb.ItemSet(like),
... ('user', 'follow', 'user'): gb.ItemSet(follow)})
... "user:like:item": gb.ItemSet(like),
... "user:follow:user": gb.ItemSet(follow)})
>>> list(item_set)
[{('user', 'like', 'item'): (tensor(0), tensor(0), tensor([0, 1]))},
{('user', 'like', 'item'): (tensor(1), tensor(1), tensor([2, 3]))},
{('user', 'follow', 'user'): (tensor(0), tensor(3), tensor([0, 1]))},
{('user', 'follow', 'user'): (tensor(1), tensor(4), tensor([2, 3]))},
{('user', 'follow', 'user'): (tensor(2), tensor(5), tensor([4, 5]))}]
[{"user:like:item": (tensor(0), tensor(0), tensor([0, 1]))},
{"user:like:item": (tensor(1), tensor(1), tensor([2, 3]))},
{"user:follow:user": (tensor(0), tensor(3), tensor([0, 1]))},
{"user:follow:user": (tensor(1), tensor(4), tensor([2, 3]))},
{"user:follow:user": (tensor(2), tensor(5), tensor([4, 5]))}]
"""
def __init__(self, itemsets: Dict[str, ItemSet]) -> None:
......
......@@ -117,15 +117,15 @@ class MinibatchSampler(IterDataPipe):
>>> node_pairs_like = (torch.arange(0, 5), torch.arange(0, 5))
>>> node_pairs_follow = (torch.arange(0, 6), torch.arange(6, 12))
>>> item_set = gb.ItemSetDict({
... ("user", "like", "item"): gb.ItemSet(node_pairs_like),
... ("user", "follow", "user"): gb.ItemSet(node_pairs_follow),
... "user:like:item": gb.ItemSet(node_pairs_like),
... "user:follow:user": gb.ItemSet(node_pairs_follow),
... })
>>> minibatch_sampler = gb.MinibatchSampler(item_set, 4)
>>> list(minibatch_sampler)
[{('user', 'like', 'item'): [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])]},
{('user', 'like', 'item'): [tensor([4]), tensor([4])],
('user', 'follow', 'user'): [tensor([0, 1, 2]), tensor([6, 7, 8])]},
{('user', 'follow', 'user'): [tensor([3, 4, 5]), tensor([ 9, 10, 11])]}]
[{"user:like:item": [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])]},
{"user:like:item": [tensor([4]), tensor([4])],
"user:follow:user": [tensor([0, 1, 2]), tensor([6, 7, 8])]},
{"user:follow:user": [tensor([3, 4, 5]), tensor([ 9, 10, 11])]}]
9. Heterogeneous node pairs and labels.
>>> like = (
......@@ -133,17 +133,17 @@ class MinibatchSampler(IterDataPipe):
>>> follow = (
... torch.arange(0, 6), torch.arange(6, 12), torch.arange(0, 6))
>>> item_set = gb.ItemSetDict({
... ("user", "like", "item"): gb.ItemSet(like),
... ("user", "follow", "user"): gb.ItemSet(follow),
... "user:like:item": gb.ItemSet(like),
... "user:follow:user": gb.ItemSet(follow),
... })
>>> minibatch_sampler = gb.MinibatchSampler(item_set, 4)
>>> list(minibatch_sampler)
[{('user', 'like', 'item'):
[{"user:like:item":
[tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])]},
{('user', 'like', 'item'): [tensor([4]), tensor([4]), tensor([4])],
('user', 'follow', 'user'):
{"user:like:item": [tensor([4]), tensor([4]), tensor([4])],
"user:follow:user":
[tensor([0, 1, 2]), tensor([6, 7, 8]), tensor([0, 1, 2])]},
{('user', 'follow', 'user'):
{"user:follow:user":
[tensor([3, 4, 5]), tensor([ 9, 10, 11]), tensor([3, 4, 5])]}]
10. Heterogeneous head, tail and negative tails.
......@@ -154,17 +154,17 @@ class MinibatchSampler(IterDataPipe):
... torch.arange(0, 6), torch.arange(6, 12),
... torch.arange(12, 24).reshape(-1, 2))
>>> item_set = gb.ItemSetDict({
... ("user", "like", "item"): gb.ItemSet(like),
... ("user", "follow", "user"): gb.ItemSet(follow),
... "user:like:item": gb.ItemSet(like),
... "user:follow:user": gb.ItemSet(follow),
... })
>>> minibatch_sampler = gb.MinibatchSampler(item_set, 4)
>>> list(minibatch_sampler)
[{('user', 'like', 'item'): [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]),
[{"user:like:item": [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]),
tensor([[ 5, 6], [ 7, 8], [ 9, 10], [11, 12]])]},
{('user', 'like', 'item'): [tensor([4]), tensor([4]), tensor([[13, 14]])],
('user', 'follow', 'user'): [tensor([0, 1, 2]), tensor([6, 7, 8]),
{"user:like:item": [tensor([4]), tensor([4]), tensor([[13, 14]])],
"user:follow:user": [tensor([0, 1, 2]), tensor([6, 7, 8]),
tensor([[12, 13], [14, 15], [16, 17]])]},
{('user', 'follow', 'user'): [tensor([3, 4, 5]), tensor([ 9, 10, 11]),
{"user:follow:user": [tensor([3, 4, 5]), tensor([ 9, 10, 11]),
tensor([[18, 19], [20, 21], [22, 23]])]}]
"""
......
......@@ -50,8 +50,8 @@ def test_ItemSetDict_valid_length():
follow = (torch.arange(0, 5), torch.arange(5, 10))
item_set = gb.ItemSetDict(
{
("user", "like", "item"): gb.ItemSet(like),
("user", "follow", "user"): gb.ItemSet(follow),
"user:like:item": gb.ItemSet(like),
"user:follow:user": gb.ItemSet(follow),
}
)
assert len(item_set) == len(like[0]) + len(follow[0])
......@@ -75,12 +75,8 @@ def test_ItemSetDict_invalid_length():
# Tuple of iterables.
item_set = gb.ItemSetDict(
{
("user", "like", "item"): gb.ItemSet(
(InvalidLength(), InvalidLength())
),
("user", "follow", "user"): gb.ItemSet(
(InvalidLength(), InvalidLength())
),
"user:like:item": gb.ItemSet((InvalidLength(), InvalidLength())),
"user:follow:user": gb.ItemSet((InvalidLength(), InvalidLength())),
}
)
with pytest.raises(TypeError):
......@@ -137,8 +133,8 @@ def test_ItemSet_head_tail_neg_tails():
def test_ItemSetDict_node_edge_ids():
# Node or edge IDs
ids = {
("user", "like", "item"): gb.ItemSet(torch.arange(0, 5)),
("user", "follow", "user"): gb.ItemSet(torch.arange(0, 5)),
"user:like:item": gb.ItemSet(torch.arange(0, 5)),
"user:follow:user": gb.ItemSet(torch.arange(0, 5)),
}
chained_ids = []
for key, value in ids.items():
......@@ -155,8 +151,8 @@ def test_ItemSetDict_node_pairs():
# Node pairs.
node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
node_pairs_dict = {
("user", "like", "item"): gb.ItemSet(node_pairs),
("user", "follow", "user"): gb.ItemSet(node_pairs),
"user:like:item": gb.ItemSet(node_pairs),
"user:follow:user": gb.ItemSet(node_pairs),
}
expected_data = []
for key, value in node_pairs_dict.items():
......@@ -174,12 +170,8 @@ def test_ItemSetDict_node_pairs_labels():
node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
labels = torch.randint(0, 3, (5,))
node_pairs_dict = {
("user", "like", "item"): gb.ItemSet(
(node_pairs[0], node_pairs[1], labels)
),
("user", "follow", "user"): gb.ItemSet(
(node_pairs[0], node_pairs[1], labels)
),
"user:like:item": gb.ItemSet((node_pairs[0], node_pairs[1], labels)),
"user:follow:user": gb.ItemSet((node_pairs[0], node_pairs[1], labels)),
}
expected_data = []
for key, value in node_pairs_dict.items():
......@@ -199,8 +191,8 @@ def test_ItemSetDict_head_tail_neg_tails():
neg_tails = torch.arange(10, 20).reshape(5, 2)
item_set = gb.ItemSet((heads, tails, neg_tails))
data_dict = {
("user", "like", "item"): gb.ItemSet((heads, tails, neg_tails)),
("user", "follow", "user"): gb.ItemSet((heads, tails, neg_tails)),
"user:like:item": gb.ItemSet((heads, tails, neg_tails)),
"user:follow:user": gb.ItemSet((heads, tails, neg_tails)),
}
expected_data = []
for key, value in data_dict.items():
......
......@@ -266,8 +266,8 @@ def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last):
torch.arange(num_ids * 3, num_ids * 4),
)
node_pairs_dict = {
("user", "like", "item"): gb.ItemSet(node_pairs_0),
("user", "follow", "user"): gb.ItemSet(node_pairs_1),
"user:like:item": gb.ItemSet(node_pairs_0),
"user:follow:user": gb.ItemSet(node_pairs_1),
}
item_set = gb.ItemSetDict(node_pairs_dict)
minibatch_sampler = gb.MinibatchSampler(
......@@ -319,10 +319,10 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
)
labels = torch.arange(0, num_ids)
node_pairs_dict = {
("user", "like", "item"): gb.ItemSet(
"user:like:item": gb.ItemSet(
(node_pairs_0[0], node_pairs_0[1], labels)
),
("user", "follow", "user"): gb.ItemSet(
"user:follow:user": gb.ItemSet(
(node_pairs_1[0], node_pairs_1[1], labels + num_ids * 2)
),
}
......@@ -380,8 +380,8 @@ def test_ItemSetDict_head_tail_neg_tails(batch_size, shuffle, drop_last):
tails = torch.arange(num_ids, num_ids * 2)
neg_tails = torch.stack((heads + 1, heads + 2), dim=-1)
data_dict = {
("user", "like", "item"): gb.ItemSet((heads, tails, neg_tails)),
("user", "follow", "user"): gb.ItemSet((heads, tails, neg_tails)),
"user:like:item": gb.ItemSet((heads, tails, neg_tails)),
"user:follow:user": gb.ItemSet((heads, tails, neg_tails)),
}
item_set = gb.ItemSetDict(data_dict)
minibatch_sampler = gb.MinibatchSampler(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment