Unverified Commit 2144a3ce authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] use str instead of tuple for canonical etype (#6216)


Co-authored-by: default avatarHongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent 4c5489e8
...@@ -100,14 +100,14 @@ class ItemSetDict: ...@@ -100,14 +100,14 @@ class ItemSetDict:
>>> node_pairs_like = (torch.arange(0, 2), torch.arange(0, 2)) >>> node_pairs_like = (torch.arange(0, 2), torch.arange(0, 2))
>>> node_pairs_follow = (torch.arange(0, 3), torch.arange(3, 6)) >>> node_pairs_follow = (torch.arange(0, 3), torch.arange(3, 6))
>>> item_set = gb.ItemSetDict({ >>> item_set = gb.ItemSetDict({
... ('user', 'like', 'item'): gb.ItemSet(node_pairs_like), ... "user:like:item": gb.ItemSet(node_pairs_like),
... ('user', 'follow', 'user'): gb.ItemSet(node_pairs_follow)}) ... "user:follow:user": gb.ItemSet(node_pairs_follow)})
>>> list(item_set) >>> list(item_set)
[{('user', 'like', 'item'): (tensor(0), tensor(0))}, [{"user:like:item": (tensor(0), tensor(0))},
{('user', 'like', 'item'): (tensor(1), tensor(1))}, {"user:like:item": (tensor(1), tensor(1))},
{('user', 'follow', 'user'): (tensor(0), tensor(3))}, {"user:follow:user": (tensor(0), tensor(3))},
{('user', 'follow', 'user'): (tensor(1), tensor(4))}, {"user:follow:user": (tensor(1), tensor(4))},
{('user', 'follow', 'user'): (tensor(2), tensor(5))}] {"user:follow:user": (tensor(2), tensor(5))}]
3. Tuple of iterables with different shape. 3. Tuple of iterables with different shape.
>>> like = (torch.arange(0, 2), torch.arange(0, 2), >>> like = (torch.arange(0, 2), torch.arange(0, 2),
...@@ -115,14 +115,14 @@ class ItemSetDict: ...@@ -115,14 +115,14 @@ class ItemSetDict:
>>> follow = (torch.arange(0, 3), torch.arange(3, 6), >>> follow = (torch.arange(0, 3), torch.arange(3, 6),
... torch.arange(0, 6).reshape(-1, 2)) ... torch.arange(0, 6).reshape(-1, 2))
>>> item_set = gb.ItemSetDict({ >>> item_set = gb.ItemSetDict({
... ('user', 'like', 'item'): gb.ItemSet(like), ... "user:like:item": gb.ItemSet(like),
... ('user', 'follow', 'user'): gb.ItemSet(follow)}) ... "user:follow:user": gb.ItemSet(follow)})
>>> list(item_set) >>> list(item_set)
[{('user', 'like', 'item'): (tensor(0), tensor(0), tensor([0, 1]))}, [{"user:like:item": (tensor(0), tensor(0), tensor([0, 1]))},
{('user', 'like', 'item'): (tensor(1), tensor(1), tensor([2, 3]))}, {"user:like:item": (tensor(1), tensor(1), tensor([2, 3]))},
{('user', 'follow', 'user'): (tensor(0), tensor(3), tensor([0, 1]))}, {"user:follow:user": (tensor(0), tensor(3), tensor([0, 1]))},
{('user', 'follow', 'user'): (tensor(1), tensor(4), tensor([2, 3]))}, {"user:follow:user": (tensor(1), tensor(4), tensor([2, 3]))},
{('user', 'follow', 'user'): (tensor(2), tensor(5), tensor([4, 5]))}] {"user:follow:user": (tensor(2), tensor(5), tensor([4, 5]))}]
""" """
def __init__(self, itemsets: Dict[str, ItemSet]) -> None: def __init__(self, itemsets: Dict[str, ItemSet]) -> None:
......
...@@ -117,15 +117,15 @@ class MinibatchSampler(IterDataPipe): ...@@ -117,15 +117,15 @@ class MinibatchSampler(IterDataPipe):
>>> node_pairs_like = (torch.arange(0, 5), torch.arange(0, 5)) >>> node_pairs_like = (torch.arange(0, 5), torch.arange(0, 5))
>>> node_pairs_follow = (torch.arange(0, 6), torch.arange(6, 12)) >>> node_pairs_follow = (torch.arange(0, 6), torch.arange(6, 12))
>>> item_set = gb.ItemSetDict({ >>> item_set = gb.ItemSetDict({
... ("user", "like", "item"): gb.ItemSet(node_pairs_like), ... "user:like:item": gb.ItemSet(node_pairs_like),
... ("user", "follow", "user"): gb.ItemSet(node_pairs_follow), ... "user:follow:user": gb.ItemSet(node_pairs_follow),
... }) ... })
>>> minibatch_sampler = gb.MinibatchSampler(item_set, 4) >>> minibatch_sampler = gb.MinibatchSampler(item_set, 4)
>>> list(minibatch_sampler) >>> list(minibatch_sampler)
[{('user', 'like', 'item'): [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])]}, [{"user:like:item": [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])]},
{('user', 'like', 'item'): [tensor([4]), tensor([4])], {"user:like:item": [tensor([4]), tensor([4])],
('user', 'follow', 'user'): [tensor([0, 1, 2]), tensor([6, 7, 8])]}, "user:follow:user": [tensor([0, 1, 2]), tensor([6, 7, 8])]},
{('user', 'follow', 'user'): [tensor([3, 4, 5]), tensor([ 9, 10, 11])]}] {"user:follow:user": [tensor([3, 4, 5]), tensor([ 9, 10, 11])]}]
9. Heterogeneous node pairs and labels. 9. Heterogeneous node pairs and labels.
>>> like = ( >>> like = (
...@@ -133,17 +133,17 @@ class MinibatchSampler(IterDataPipe): ...@@ -133,17 +133,17 @@ class MinibatchSampler(IterDataPipe):
>>> follow = ( >>> follow = (
... torch.arange(0, 6), torch.arange(6, 12), torch.arange(0, 6)) ... torch.arange(0, 6), torch.arange(6, 12), torch.arange(0, 6))
>>> item_set = gb.ItemSetDict({ >>> item_set = gb.ItemSetDict({
... ("user", "like", "item"): gb.ItemSet(like), ... "user:like:item": gb.ItemSet(like),
... ("user", "follow", "user"): gb.ItemSet(follow), ... "user:follow:user": gb.ItemSet(follow),
... }) ... })
>>> minibatch_sampler = gb.MinibatchSampler(item_set, 4) >>> minibatch_sampler = gb.MinibatchSampler(item_set, 4)
>>> list(minibatch_sampler) >>> list(minibatch_sampler)
[{('user', 'like', 'item'): [{"user:like:item":
[tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])]}, [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])]},
{('user', 'like', 'item'): [tensor([4]), tensor([4]), tensor([4])], {"user:like:item": [tensor([4]), tensor([4]), tensor([4])],
('user', 'follow', 'user'): "user:follow:user":
[tensor([0, 1, 2]), tensor([6, 7, 8]), tensor([0, 1, 2])]}, [tensor([0, 1, 2]), tensor([6, 7, 8]), tensor([0, 1, 2])]},
{('user', 'follow', 'user'): {"user:follow:user":
[tensor([3, 4, 5]), tensor([ 9, 10, 11]), tensor([3, 4, 5])]}] [tensor([3, 4, 5]), tensor([ 9, 10, 11]), tensor([3, 4, 5])]}]
10. Heterogeneous head, tail and negative tails. 10. Heterogeneous head, tail and negative tails.
...@@ -154,17 +154,17 @@ class MinibatchSampler(IterDataPipe): ...@@ -154,17 +154,17 @@ class MinibatchSampler(IterDataPipe):
... torch.arange(0, 6), torch.arange(6, 12), ... torch.arange(0, 6), torch.arange(6, 12),
... torch.arange(12, 24).reshape(-1, 2)) ... torch.arange(12, 24).reshape(-1, 2))
>>> item_set = gb.ItemSetDict({ >>> item_set = gb.ItemSetDict({
... ("user", "like", "item"): gb.ItemSet(like), ... "user:like:item": gb.ItemSet(like),
... ("user", "follow", "user"): gb.ItemSet(follow), ... "user:follow:user": gb.ItemSet(follow),
... }) ... })
>>> minibatch_sampler = gb.MinibatchSampler(item_set, 4) >>> minibatch_sampler = gb.MinibatchSampler(item_set, 4)
>>> list(minibatch_sampler) >>> list(minibatch_sampler)
[{('user', 'like', 'item'): [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]), [{"user:like:item": [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]),
tensor([[ 5, 6], [ 7, 8], [ 9, 10], [11, 12]])]}, tensor([[ 5, 6], [ 7, 8], [ 9, 10], [11, 12]])]},
{('user', 'like', 'item'): [tensor([4]), tensor([4]), tensor([[13, 14]])], {"user:like:item": [tensor([4]), tensor([4]), tensor([[13, 14]])],
('user', 'follow', 'user'): [tensor([0, 1, 2]), tensor([6, 7, 8]), "user:follow:user": [tensor([0, 1, 2]), tensor([6, 7, 8]),
tensor([[12, 13], [14, 15], [16, 17]])]}, tensor([[12, 13], [14, 15], [16, 17]])]},
{('user', 'follow', 'user'): [tensor([3, 4, 5]), tensor([ 9, 10, 11]), {"user:follow:user": [tensor([3, 4, 5]), tensor([ 9, 10, 11]),
tensor([[18, 19], [20, 21], [22, 23]])]}] tensor([[18, 19], [20, 21], [22, 23]])]}]
""" """
......
...@@ -50,8 +50,8 @@ def test_ItemSetDict_valid_length(): ...@@ -50,8 +50,8 @@ def test_ItemSetDict_valid_length():
follow = (torch.arange(0, 5), torch.arange(5, 10)) follow = (torch.arange(0, 5), torch.arange(5, 10))
item_set = gb.ItemSetDict( item_set = gb.ItemSetDict(
{ {
("user", "like", "item"): gb.ItemSet(like), "user:like:item": gb.ItemSet(like),
("user", "follow", "user"): gb.ItemSet(follow), "user:follow:user": gb.ItemSet(follow),
} }
) )
assert len(item_set) == len(like[0]) + len(follow[0]) assert len(item_set) == len(like[0]) + len(follow[0])
...@@ -75,12 +75,8 @@ def test_ItemSetDict_invalid_length(): ...@@ -75,12 +75,8 @@ def test_ItemSetDict_invalid_length():
# Tuple of iterables. # Tuple of iterables.
item_set = gb.ItemSetDict( item_set = gb.ItemSetDict(
{ {
("user", "like", "item"): gb.ItemSet( "user:like:item": gb.ItemSet((InvalidLength(), InvalidLength())),
(InvalidLength(), InvalidLength()) "user:follow:user": gb.ItemSet((InvalidLength(), InvalidLength())),
),
("user", "follow", "user"): gb.ItemSet(
(InvalidLength(), InvalidLength())
),
} }
) )
with pytest.raises(TypeError): with pytest.raises(TypeError):
...@@ -137,8 +133,8 @@ def test_ItemSet_head_tail_neg_tails(): ...@@ -137,8 +133,8 @@ def test_ItemSet_head_tail_neg_tails():
def test_ItemSetDict_node_edge_ids(): def test_ItemSetDict_node_edge_ids():
# Node or edge IDs # Node or edge IDs
ids = { ids = {
("user", "like", "item"): gb.ItemSet(torch.arange(0, 5)), "user:like:item": gb.ItemSet(torch.arange(0, 5)),
("user", "follow", "user"): gb.ItemSet(torch.arange(0, 5)), "user:follow:user": gb.ItemSet(torch.arange(0, 5)),
} }
chained_ids = [] chained_ids = []
for key, value in ids.items(): for key, value in ids.items():
...@@ -155,8 +151,8 @@ def test_ItemSetDict_node_pairs(): ...@@ -155,8 +151,8 @@ def test_ItemSetDict_node_pairs():
# Node pairs. # Node pairs.
node_pairs = (torch.arange(0, 5), torch.arange(5, 10)) node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
node_pairs_dict = { node_pairs_dict = {
("user", "like", "item"): gb.ItemSet(node_pairs), "user:like:item": gb.ItemSet(node_pairs),
("user", "follow", "user"): gb.ItemSet(node_pairs), "user:follow:user": gb.ItemSet(node_pairs),
} }
expected_data = [] expected_data = []
for key, value in node_pairs_dict.items(): for key, value in node_pairs_dict.items():
...@@ -174,12 +170,8 @@ def test_ItemSetDict_node_pairs_labels(): ...@@ -174,12 +170,8 @@ def test_ItemSetDict_node_pairs_labels():
node_pairs = (torch.arange(0, 5), torch.arange(5, 10)) node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
labels = torch.randint(0, 3, (5,)) labels = torch.randint(0, 3, (5,))
node_pairs_dict = { node_pairs_dict = {
("user", "like", "item"): gb.ItemSet( "user:like:item": gb.ItemSet((node_pairs[0], node_pairs[1], labels)),
(node_pairs[0], node_pairs[1], labels) "user:follow:user": gb.ItemSet((node_pairs[0], node_pairs[1], labels)),
),
("user", "follow", "user"): gb.ItemSet(
(node_pairs[0], node_pairs[1], labels)
),
} }
expected_data = [] expected_data = []
for key, value in node_pairs_dict.items(): for key, value in node_pairs_dict.items():
...@@ -199,8 +191,8 @@ def test_ItemSetDict_head_tail_neg_tails(): ...@@ -199,8 +191,8 @@ def test_ItemSetDict_head_tail_neg_tails():
neg_tails = torch.arange(10, 20).reshape(5, 2) neg_tails = torch.arange(10, 20).reshape(5, 2)
item_set = gb.ItemSet((heads, tails, neg_tails)) item_set = gb.ItemSet((heads, tails, neg_tails))
data_dict = { data_dict = {
("user", "like", "item"): gb.ItemSet((heads, tails, neg_tails)), "user:like:item": gb.ItemSet((heads, tails, neg_tails)),
("user", "follow", "user"): gb.ItemSet((heads, tails, neg_tails)), "user:follow:user": gb.ItemSet((heads, tails, neg_tails)),
} }
expected_data = [] expected_data = []
for key, value in data_dict.items(): for key, value in data_dict.items():
......
...@@ -266,8 +266,8 @@ def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last): ...@@ -266,8 +266,8 @@ def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last):
torch.arange(num_ids * 3, num_ids * 4), torch.arange(num_ids * 3, num_ids * 4),
) )
node_pairs_dict = { node_pairs_dict = {
("user", "like", "item"): gb.ItemSet(node_pairs_0), "user:like:item": gb.ItemSet(node_pairs_0),
("user", "follow", "user"): gb.ItemSet(node_pairs_1), "user:follow:user": gb.ItemSet(node_pairs_1),
} }
item_set = gb.ItemSetDict(node_pairs_dict) item_set = gb.ItemSetDict(node_pairs_dict)
minibatch_sampler = gb.MinibatchSampler( minibatch_sampler = gb.MinibatchSampler(
...@@ -319,10 +319,10 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last): ...@@ -319,10 +319,10 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
) )
labels = torch.arange(0, num_ids) labels = torch.arange(0, num_ids)
node_pairs_dict = { node_pairs_dict = {
("user", "like", "item"): gb.ItemSet( "user:like:item": gb.ItemSet(
(node_pairs_0[0], node_pairs_0[1], labels) (node_pairs_0[0], node_pairs_0[1], labels)
), ),
("user", "follow", "user"): gb.ItemSet( "user:follow:user": gb.ItemSet(
(node_pairs_1[0], node_pairs_1[1], labels + num_ids * 2) (node_pairs_1[0], node_pairs_1[1], labels + num_ids * 2)
), ),
} }
...@@ -380,8 +380,8 @@ def test_ItemSetDict_head_tail_neg_tails(batch_size, shuffle, drop_last): ...@@ -380,8 +380,8 @@ def test_ItemSetDict_head_tail_neg_tails(batch_size, shuffle, drop_last):
tails = torch.arange(num_ids, num_ids * 2) tails = torch.arange(num_ids, num_ids * 2)
neg_tails = torch.stack((heads + 1, heads + 2), dim=-1) neg_tails = torch.stack((heads + 1, heads + 2), dim=-1)
data_dict = { data_dict = {
("user", "like", "item"): gb.ItemSet((heads, tails, neg_tails)), "user:like:item": gb.ItemSet((heads, tails, neg_tails)),
("user", "follow", "user"): gb.ItemSet((heads, tails, neg_tails)), "user:follow:user": gb.ItemSet((heads, tails, neg_tails)),
} }
item_set = gb.ItemSetDict(data_dict) item_set = gb.ItemSetDict(data_dict)
minibatch_sampler = gb.MinibatchSampler( minibatch_sampler = gb.MinibatchSampler(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment