Unverified Commit c8ec9ce3 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] enable indexing on ItemSet instance (#6439)

parent a548805d
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
from typing import Dict, Iterable, Iterator, Sized, Tuple, Union from typing import Dict, Iterable, Iterator, Sized, Tuple, Union
import torch
__all__ = ["ItemSet", "ItemSetDict"] __all__ = ["ItemSet", "ItemSetDict"]
...@@ -33,7 +35,10 @@ class ItemSet: ...@@ -33,7 +35,10 @@ class ItemSet:
>>> num = 10 >>> num = 10
>>> item_set = gb.ItemSet(num, names="seed_nodes") >>> item_set = gb.ItemSet(num, names="seed_nodes")
>>> list(item_set) >>> list(item_set)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9] [tensor(0), tensor(1), tensor(2), tensor(3), tensor(4), tensor(5),
tensor(6), tensor(7), tensor(8), tensor(9)]
>>> item_set[torch.arange(0, num)]
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> item_set.names >>> item_set.names
('seed_nodes',) ('seed_nodes',)
...@@ -42,6 +47,8 @@ class ItemSet: ...@@ -42,6 +47,8 @@ class ItemSet:
>>> item_set = gb.ItemSet(node_ids, names="seed_nodes") >>> item_set = gb.ItemSet(node_ids, names="seed_nodes")
>>> list(item_set) >>> list(item_set)
[tensor(0), tensor(1), tensor(2), tensor(3), tensor(4)] [tensor(0), tensor(1), tensor(2), tensor(3), tensor(4)]
>>> item_set[:]
tensor([0, 1, 2, 3, 4])
>>> item_set.names >>> item_set.names
('seed_nodes',) ('seed_nodes',)
...@@ -53,6 +60,8 @@ class ItemSet: ...@@ -53,6 +60,8 @@ class ItemSet:
>>> list(item_set) >>> list(item_set)
[(tensor(0), tensor(5)), (tensor(1), tensor(6)), (tensor(2), tensor(7)), [(tensor(0), tensor(5)), (tensor(1), tensor(6)), (tensor(2), tensor(7)),
(tensor(3), tensor(8)), (tensor(4), tensor(9))] (tensor(3), tensor(8)), (tensor(4), tensor(9))]
>>> item_set[:]
(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9]))
>>> item_set.names >>> item_set.names
('seed_nodes', 'labels') ('seed_nodes', 'labels')
...@@ -67,6 +76,10 @@ class ItemSet: ...@@ -67,6 +76,10 @@ class ItemSet:
(tensor([4, 5]), tensor([16, 17, 18])), (tensor([4, 5]), tensor([16, 17, 18])),
(tensor([6, 7]), tensor([19, 20, 21])), (tensor([6, 7]), tensor([19, 20, 21])),
(tensor([8, 9]), tensor([22, 23, 24]))] (tensor([8, 9]), tensor([22, 23, 24]))]
>>> item_set[:]
(tensor([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]),
tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18], [19, 20, 21],
[22, 23, 24]]))
>>> item_set.names >>> item_set.names
('node_pairs', 'negative_dsts') ('node_pairs', 'negative_dsts')
""" """
...@@ -76,33 +89,20 @@ class ItemSet: ...@@ -76,33 +89,20 @@ class ItemSet:
items: Union[int, Iterable, Tuple[Iterable]], items: Union[int, Iterable, Tuple[Iterable]],
names: Union[str, Tuple[str]] = None, names: Union[str, Tuple[str]] = None,
) -> None: ) -> None:
# Initiated by an integer. if isinstance(items, (int, tuple)):
if isinstance(items, int):
self._items = items
if names is not None:
if isinstance(names, tuple):
self._names = names
else:
self._names = (names,)
assert (
len(self._names) == 1
), "Number of names mustn't exceed 1 when item is an integer."
else:
self._names = None
return
# Otherwise.
if isinstance(items, tuple):
self._items = items self._items = items
else: else:
self._items = (items,) self._items = (items,)
if names is not None: if names is not None:
num_items = (
len(self._items) if isinstance(self._items, tuple) else 1
)
if isinstance(names, tuple): if isinstance(names, tuple):
self._names = names self._names = names
else: else:
self._names = (names,) self._names = (names,)
assert len(self._items) == len(self._names), ( assert num_items == len(self._names), (
f"Number of items ({len(self._items)}) and " f"Number of items ({num_items}) and "
f"names ({len(self._names)}) must match." f"names ({len(self._names)}) must match."
) )
else: else:
...@@ -110,7 +110,7 @@ class ItemSet: ...@@ -110,7 +110,7 @@ class ItemSet:
def __iter__(self) -> Iterator: def __iter__(self) -> Iterator:
if isinstance(self._items, int): if isinstance(self._items, int):
yield from range(self._items) yield from torch.arange(self._items)
return return
if len(self._items) == 1: if len(self._items) == 1:
...@@ -143,6 +143,24 @@ class ItemSet: ...@@ -143,6 +143,24 @@ class ItemSet:
f"{type(self).__name__} instance doesn't have valid length." f"{type(self).__name__} instance doesn't have valid length."
) )
def __getitem__(self, idx: Union[int, slice, Iterable]) -> Tuple:
try:
len(self)
except TypeError:
raise TypeError(
f"{type(self).__name__} instance doesn't support indexing."
)
if isinstance(self._items, int):
assert isinstance(idx, (int, torch.Tensor)), (
f"Indexing of integer-initialized {type(self).__name__} "
f"instance must be int or torch.Tensor."
)
# [Warning] Index range is not checked.
return idx
if len(self._items) == 1:
return self._items[0][idx]
return tuple(item[idx] for item in self._items)
@property @property
def names(self) -> Tuple[str]: def names(self) -> Tuple[str]:
"""Return the names of the items.""" """Return the names of the items."""
......
...@@ -26,9 +26,7 @@ def test_ItemSet_names(): ...@@ -26,9 +26,7 @@ def test_ItemSet_names():
# Integer-initiated ItemSet with excessive names. # Integer-initiated ItemSet with excessive names.
with pytest.raises( with pytest.raises(
AssertionError, AssertionError,
match=re.escape( match=re.escape("Number of items (1) and names (2) must match."),
"Number of names mustn't exceed 1 when item is an integer."
),
): ):
_ = gb.ItemSet(5, names=("seed_nodes", "labels")) _ = gb.ItemSet(5, names=("seed_nodes", "labels"))
...@@ -69,61 +67,123 @@ def test_ItemSet_length(): ...@@ -69,61 +67,123 @@ def test_ItemSet_length():
# Single iterable with invalid length. # Single iterable with invalid length.
item_set = gb.ItemSet(InvalidLength()) item_set = gb.ItemSet(InvalidLength())
with pytest.raises(TypeError): with pytest.raises(
TypeError, match="ItemSet instance doesn't have valid length."
):
_ = len(item_set) _ = len(item_set)
with pytest.raises(
TypeError, match="ItemSet instance doesn't support indexing."
):
_ = item_set[0]
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert i == item assert i == item
# Tuple of iterables with invalid length. # Tuple of iterables with invalid length.
item_set = gb.ItemSet((InvalidLength(), InvalidLength())) item_set = gb.ItemSet((InvalidLength(), InvalidLength()))
with pytest.raises(TypeError): with pytest.raises(
TypeError, match="ItemSet instance doesn't have valid length."
):
_ = len(item_set) _ = len(item_set)
with pytest.raises(
TypeError, match="ItemSet instance doesn't support indexing."
):
_ = item_set[0]
for i, (item1, item2) in enumerate(item_set): for i, (item1, item2) in enumerate(item_set):
assert i == item1 assert i == item1
assert i == item2 assert i == item2
def test_ItemSet_iteration_seed_nodes(): def test_ItemSet_seed_nodes():
# Node IDs. # Node IDs with tensor.
item_set = gb.ItemSet(torch.arange(0, 5), names="seed_nodes") item_set = gb.ItemSet(torch.arange(0, 5), names="seed_nodes")
assert item_set.names == ("seed_nodes",) assert item_set.names == ("seed_nodes",)
# Iterating over ItemSet and indexing one by one.
for i, item in enumerate(item_set):
assert i == item.item()
assert i == item_set[i]
# Indexing with a slice.
assert torch.equal(item_set[:], torch.arange(0, 5))
# Indexing with an Iterable.
assert torch.equal(item_set[torch.arange(0, 5)], torch.arange(0, 5))
# Node IDs with single integer.
item_set = gb.ItemSet(5, names="seed_nodes")
assert item_set.names == ("seed_nodes",)
# Iterating over ItemSet and indexing one by one.
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert i == item.item() assert i == item.item()
assert i == item_set[i]
# Indexing with a slice.
with pytest.raises(
AssertionError,
match=(
"Indexing of integer-initialized ItemSet instance must be int or "
"torch.Tensor."
),
):
_ = item_set[:]
# Indexing with a Tensor.
assert torch.equal(item_set[torch.arange(0, 5)], torch.arange(0, 5))
def test_ItemSet_iteration_seed_nodes_labels(): def test_ItemSet_seed_nodes_labels():
# Node IDs and labels. # Node IDs and labels.
seed_nodes = torch.arange(0, 5) seed_nodes = torch.arange(0, 5)
labels = torch.randint(0, 3, (5,)) labels = torch.randint(0, 3, (5,))
item_set = gb.ItemSet((seed_nodes, labels), names=("seed_nodes", "labels")) item_set = gb.ItemSet((seed_nodes, labels), names=("seed_nodes", "labels"))
assert item_set.names == ("seed_nodes", "labels") assert item_set.names == ("seed_nodes", "labels")
# Iterating over ItemSet and indexing one by one.
for i, (seed_node, label) in enumerate(item_set): for i, (seed_node, label) in enumerate(item_set):
assert seed_node == seed_nodes[i] assert seed_node == seed_nodes[i]
assert label == labels[i] assert label == labels[i]
assert seed_node == item_set[i][0]
assert label == item_set[i][1]
# Indexing with a slice.
assert torch.equal(item_set[:][0], seed_nodes)
assert torch.equal(item_set[:][1], labels)
# Indexing with an Iterable.
assert torch.equal(item_set[torch.arange(0, 5)][0], seed_nodes)
assert torch.equal(item_set[torch.arange(0, 5)][1], labels)
def test_ItemSet_iteration_node_pairs(): def test_ItemSet_node_pairs():
# Node pairs. # Node pairs.
node_pairs = torch.arange(0, 10).reshape(-1, 2) node_pairs = torch.arange(0, 10).reshape(-1, 2)
item_set = gb.ItemSet(node_pairs, names="node_pairs") item_set = gb.ItemSet(node_pairs, names="node_pairs")
assert item_set.names == ("node_pairs",) assert item_set.names == ("node_pairs",)
# Iterating over ItemSet and indexing one by one.
for i, (src, dst) in enumerate(item_set): for i, (src, dst) in enumerate(item_set):
assert node_pairs[i][0] == src assert node_pairs[i][0] == src
assert node_pairs[i][1] == dst assert node_pairs[i][1] == dst
assert node_pairs[i][0] == item_set[i][0]
assert node_pairs[i][1] == item_set[i][1]
# Indexing with a slice.
assert torch.equal(item_set[:], node_pairs)
# Indexing with an Iterable.
assert torch.equal(item_set[torch.arange(0, 5)], node_pairs)
def test_ItemSet_iteration_node_pairs_labels(): def test_ItemSet_node_pairs_labels():
# Node pairs and labels # Node pairs and labels
node_pairs = torch.arange(0, 10).reshape(-1, 2) node_pairs = torch.arange(0, 10).reshape(-1, 2)
labels = torch.randint(0, 3, (5,)) labels = torch.randint(0, 3, (5,))
item_set = gb.ItemSet((node_pairs, labels), names=("node_pairs", "labels")) item_set = gb.ItemSet((node_pairs, labels), names=("node_pairs", "labels"))
assert item_set.names == ("node_pairs", "labels") assert item_set.names == ("node_pairs", "labels")
# Iterating over ItemSet and indexing one by one.
for i, (node_pair, label) in enumerate(item_set): for i, (node_pair, label) in enumerate(item_set):
assert torch.equal(node_pairs[i], node_pair) assert torch.equal(node_pairs[i], node_pair)
assert labels[i] == label assert labels[i] == label
assert torch.equal(node_pairs[i], item_set[i][0])
assert labels[i] == item_set[i][1]
# Indexing with a slice.
assert torch.equal(item_set[:][0], node_pairs)
assert torch.equal(item_set[:][1], labels)
# Indexing with an Iterable.
assert torch.equal(item_set[torch.arange(0, 5)][0], node_pairs)
assert torch.equal(item_set[torch.arange(0, 5)][1], labels)
def test_ItemSet_iteration_node_pairs_neg_dsts(): def test_ItemSet_node_pairs_neg_dsts():
# Node pairs and negative destinations. # Node pairs and negative destinations.
node_pairs = torch.arange(0, 10).reshape(-1, 2) node_pairs = torch.arange(0, 10).reshape(-1, 2)
neg_dsts = torch.arange(10, 25).reshape(-1, 3) neg_dsts = torch.arange(10, 25).reshape(-1, 3)
...@@ -131,18 +191,31 @@ def test_ItemSet_iteration_node_pairs_neg_dsts(): ...@@ -131,18 +191,31 @@ def test_ItemSet_iteration_node_pairs_neg_dsts():
(node_pairs, neg_dsts), names=("node_pairs", "negative_dsts") (node_pairs, neg_dsts), names=("node_pairs", "negative_dsts")
) )
assert item_set.names == ("node_pairs", "negative_dsts") assert item_set.names == ("node_pairs", "negative_dsts")
# Iterating over ItemSet and indexing one by one.
for i, (node_pair, neg_dst) in enumerate(item_set): for i, (node_pair, neg_dst) in enumerate(item_set):
assert torch.equal(node_pairs[i], node_pair) assert torch.equal(node_pairs[i], node_pair)
assert torch.equal(neg_dsts[i], neg_dst) assert torch.equal(neg_dsts[i], neg_dst)
assert torch.equal(node_pairs[i], item_set[i][0])
assert torch.equal(neg_dsts[i], item_set[i][1])
# Indexing with a slice.
assert torch.equal(item_set[:][0], node_pairs)
assert torch.equal(item_set[:][1], neg_dsts)
# Indexing with an Iterable.
assert torch.equal(item_set[torch.arange(0, 5)][0], node_pairs)
assert torch.equal(item_set[torch.arange(0, 5)][1], neg_dsts)
def test_ItemSet_iteration_graphs(): def test_ItemSet_graphs():
# Graphs. # Graphs.
graphs = [dgl.rand_graph(10, 20) for _ in range(5)] graphs = [dgl.rand_graph(10, 20) for _ in range(5)]
item_set = gb.ItemSet(graphs) item_set = gb.ItemSet(graphs)
assert item_set.names is None assert item_set.names is None
# Iterating over ItemSet and indexing one by one.
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert graphs[i] == item assert graphs[i] == item
assert graphs[i] == item_set[i]
# Indexing with a slice.
assert item_set[:] == graphs
def test_ItemSetDict_names(): def test_ItemSetDict_names():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment