"vscode:/vscode.git/clone" did not exist on "df55f0535829e2a1348062f85c7483fb51197102"
Unverified commit 9d5b897a authored by Rhett Ying, committed by GitHub

[GraphBolt] enable indexing for ItemSetDict (#6459)

parent 8b37564b
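With this change, an ItemSetDict can be indexed with an integer or a slice, returning a dict keyed by the itemset(s) that own the requested items. A minimal usage sketch (illustrative only; it assumes the gb.ItemSet(tensor, names=...) construction used in the docstring and tests below, and shows outputs in the same style as the docstring):

>>> import torch
>>> from dgl import graphbolt as gb
>>> item_set = gb.ItemSetDict({
...     "user": gb.ItemSet(torch.arange(0, 5), names="seed_nodes"),
...     "item": gb.ItemSet(torch.arange(5, 10), names="seed_nodes"),
... })
>>> item_set[0]
{"user": tensor(0)}
>>> item_set[7]
{"item": tensor(7)}
>>> item_set[3:7]
{"user": tensor([3, 4]), "item": tensor([5, 6])}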
@@ -202,6 +202,8 @@ class ItemSetDict:
{"user": tensor(3)}, {"user": tensor(4)}, {"item": tensor(5)},
{"item": tensor(6)}, {"item": tensor(7)}, {"item": tensor(8)},
{"item": tensor(9)}}]
>>> item_set[:]
{"user": tensor([0, 1, 2, 3, 4]), "item": tensor([5, 6, 7, 8, 9])}
>>> item_set.names
('seed_nodes',)
@@ -222,6 +224,9 @@ class ItemSetDict:
[{"user": (tensor(0), tensor(0))}, {"user": (tensor(1), tensor(1))},
{"item": (tensor(2), tensor(2))}, {"item": (tensor(3), tensor(3))},
{"item": (tensor(4), tensor(4))}}]
>>> item_set[:]
{"user": (tensor([0, 1]), tensor([0, 1])),
"item": (tensor([2, 3, 4]), tensor([2, 3, 4]))}
>>> item_set.names
('seed_nodes', 'labels')
@@ -244,6 +249,13 @@ class ItemSetDict:
{"user:follow:user": (tensor([0, 1]), tensor([ 6, 7, 8, 9, 10, 11]))},
{"user:follow:user": (tensor([2, 3]), tensor([12, 13, 14, 15, 16, 17]))},
{"user:follow:user": (tensor([4, 5]), tensor([18, 19, 20, 21, 22, 23]))}]
>>> item_set[:]
{"user:like:item": (tensor([[0, 1], [2, 3]]),
tensor([[4, 5, 6], [7, 8, 9]])),
"user:follow:user": (tensor([[0, 1], [2, 3], [4, 5]]),
tensor([[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23]]))}
>>> item_set.names
('node_pairs', 'negative_dsts')
"""
@@ -254,6 +266,15 @@ class ItemSetDict:
        assert all(
            self._names == itemset.names for itemset in itemsets.values()
        ), "All itemsets must have the same names."
        try:
            # For indexable itemsets, we compute the offsets for each itemset
            # in advance to speed up indexing.
            offsets = [0] + [
                len(itemset) for itemset in self._itemsets.values()
            ]
            self._offsets = torch.tensor(offsets).cumsum(0)
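            # E.g. two itemsets of length 5 each give
            # self._offsets == tensor([0, 5, 10]); global index 7 then falls
            # into the second itemset at local position 7 - 5 = 2.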
        except TypeError:
            self._offsets = None

    def __iter__(self) -> Iterator:
        for key, itemset in self._itemsets.items():
@@ -263,6 +284,43 @@ class ItemSetDict:
    def __len__(self) -> int:
        return sum(len(itemset) for itemset in self._itemsets.values())

    def __getitem__(self, idx: Union[int, slice]) -> Dict[str, Tuple]:
        if self._offsets is None:
            raise TypeError(
                f"{type(self).__name__} instance doesn't support indexing."
            )
        total_num = self._offsets[-1]
        if isinstance(idx, int):
            if idx < 0:
                idx += total_num
            if idx < 0 or idx >= total_num:
                raise IndexError(f"{type(self).__name__} index out of range.")
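            # Locate the itemset that owns this global index: with right=True,
            # searchsorted returns the first offset greater than idx, so the
            # owning itemset sits one position earlier; subtracting its offset
            # converts the global index into a local one.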
            offset_idx = torch.searchsorted(self._offsets, idx, right=True)
            offset_idx -= 1
            idx -= self._offsets[offset_idx]
            key = list(self._itemsets.keys())[offset_idx]
            return {key: self._itemsets[key][idx]}
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(total_num)
            assert step == 1, "Step must be 1."
            assert start < stop, "Start must be smaller than stop."
            data = {}
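            # Find the first itemset that overlaps [start, stop), then slice
            # each overlapping itemset with locally shifted bounds until the
            # stop position is covered.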
            offset_idx_start = max(
                1, torch.searchsorted(self._offsets, start, right=False)
            )
            keys = list(self._itemsets.keys())
            for offset_idx in range(offset_idx_start, len(self._offsets)):
                key = keys[offset_idx - 1]
                data[key] = self._itemsets[key][
                    max(0, start - self._offsets[offset_idx - 1]) : stop
                    - self._offsets[offset_idx - 1]
                ]
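                # The current itemset already covers the requested stop, so
                # later itemsets contribute nothing.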
                if stop <= self._offsets[offset_idx]:
                    break
            return data
        raise TypeError(f"{type(self).__name__} indices must be int or slice.")

    @property
    def names(self) -> Tuple[str]:
        """Return the names of the items."""
@@ -312,8 +312,14 @@ def test_ItemSetDict_length():
"item": gb.ItemSet(InvalidLength()),
}
)
with pytest.raises(TypeError):
with pytest.raises(
TypeError, match="ItemSet instance doesn't have valid length."
):
_ = len(item_set)
with pytest.raises(
TypeError, match="ItemSetDict instance doesn't support indexing."
):
_ = item_set[0]
# Tuple of iterables with invalid length.
item_set = gb.ItemSetDict(
@@ -322,8 +328,14 @@ def test_ItemSetDict_length():
"user:follow:user": gb.ItemSet((InvalidLength(), InvalidLength())),
}
)
with pytest.raises(TypeError):
with pytest.raises(
TypeError, match="ItemSet instance doesn't have valid length."
):
_ = len(item_set)
with pytest.raises(
TypeError, match="ItemSetDict instance doesn't support indexing."
):
_ = item_set[0]
def test_ItemSetDict_iteration_seed_nodes():
@@ -339,11 +351,48 @@ def test_ItemSetDict_iteration_seed_nodes():
        chained_ids += [(key, v) for v in value]
    item_set = gb.ItemSetDict(ids)
    assert item_set.names == ("seed_nodes",)
    # Iterating over ItemSetDict and indexing one by one.
    for i, item in enumerate(item_set):
        assert len(item) == 1
        assert isinstance(item, dict)
        assert chained_ids[i][0] in item
        assert item[chained_ids[i][0]] == chained_ids[i][1]
        assert item_set[i] == item
        assert item_set[i - len(item_set)] == item
    # Indexing all with a slice.
    assert torch.equal(item_set[:]["user"], user_ids)
    assert torch.equal(item_set[:]["item"], item_ids)
    # Indexing partial with a slice.
    partial_data = item_set[:3]
    assert len(list(partial_data.keys())) == 1
    assert torch.equal(partial_data["user"], user_ids[:3])
    partial_data = item_set[7:]
    assert len(list(partial_data.keys())) == 1
    assert torch.equal(partial_data["item"], item_ids[2:])
    partial_data = item_set[3:7]
    assert len(list(partial_data.keys())) == 2
    assert torch.equal(partial_data["user"], user_ids[3:5])
    assert torch.equal(partial_data["item"], item_ids[:2])
    # Exception cases.
    with pytest.raises(AssertionError, match="Step must be 1."):
        _ = item_set[::2]
    with pytest.raises(
        AssertionError, match="Start must be smaller than stop."
    ):
        _ = item_set[5:3]
    with pytest.raises(
        AssertionError, match="Start must be smaller than stop."
    ):
        _ = item_set[-1:3]
    with pytest.raises(IndexError, match="ItemSetDict index out of range."):
        _ = item_set[20]
    with pytest.raises(IndexError, match="ItemSetDict index out of range."):
        _ = item_set[-20]
    with pytest.raises(
        TypeError, match="ItemSetDict indices must be int or slice."
    ):
        _ = item_set[torch.arange(3)]


def test_ItemSetDict_iteration_seed_nodes_labels():
@@ -365,11 +414,18 @@ def test_ItemSetDict_iteration_seed_nodes_labels():
        chained_ids += [(key, v) for v in value]
    item_set = gb.ItemSetDict(ids_labels)
    assert item_set.names == ("seed_nodes", "labels")
    # Iterating over ItemSetDict and indexing one by one.
    for i, item in enumerate(item_set):
        assert len(item) == 1
        assert isinstance(item, dict)
        assert chained_ids[i][0] in item
        assert item[chained_ids[i][0]] == chained_ids[i][1]
        assert item_set[i] == item
    # Indexing with a slice.
    assert torch.equal(item_set[:]["user"][0], user_ids)
    assert torch.equal(item_set[:]["user"][1], user_labels)
    assert torch.equal(item_set[:]["item"][0], item_ids)
    assert torch.equal(item_set[:]["item"][1], item_labels)


def test_ItemSetDict_iteration_node_pairs():
@@ -384,11 +440,18 @@ def test_ItemSetDict_iteration_node_pairs():
        expected_data += [(key, v) for v in value]
    item_set = gb.ItemSetDict(node_pairs_dict)
    assert item_set.names == ("node_pairs",)
    # Iterating over ItemSetDict and indexing one by one.
    for i, item in enumerate(item_set):
        assert len(item) == 1
        assert isinstance(item, dict)
        assert expected_data[i][0] in item
        assert torch.equal(item[expected_data[i][0]], expected_data[i][1])
        assert item_set[i].keys() == item.keys()
        key = list(item.keys())[0]
        assert torch.equal(item_set[i][key], item[key])
    # Indexing with a slice.
    assert torch.equal(item_set[:]["user:like:item"], node_pairs)
    assert torch.equal(item_set[:]["user:follow:user"], node_pairs)


def test_ItemSetDict_iteration_node_pairs_labels():
@@ -408,6 +471,7 @@ def test_ItemSetDict_iteration_node_pairs_labels():
        expected_data += [(key, v) for v in value]
    item_set = gb.ItemSetDict(node_pairs_labels)
    assert item_set.names == ("node_pairs", "labels")
    # Iterating over ItemSetDict and indexing one by one.
    for i, item in enumerate(item_set):
        assert len(item) == 1
        assert isinstance(item, dict)
@@ -415,6 +479,15 @@ def test_ItemSetDict_iteration_node_pairs_labels():
        assert key in item
        assert torch.equal(item[key][0], value[0])
        assert item[key][1] == value[1]
        assert item_set[i].keys() == item.keys()
        key = list(item.keys())[0]
        assert torch.equal(item_set[i][key][0], item[key][0])
        assert torch.equal(item_set[i][key][1], item[key][1])
    # Indexing with a slice.
    assert torch.equal(item_set[:]["user:like:item"][0], node_pairs)
    assert torch.equal(item_set[:]["user:like:item"][1], labels)
    assert torch.equal(item_set[:]["user:follow:user"][0], node_pairs)
    assert torch.equal(item_set[:]["user:follow:user"][1], labels)


def test_ItemSetDict_iteration_node_pairs_neg_dsts():
@@ -434,6 +507,7 @@ def test_ItemSetDict_iteration_node_pairs_neg_dsts():
        expected_data += [(key, v) for v in value]
    item_set = gb.ItemSetDict(node_pairs_neg_dsts)
    assert item_set.names == ("node_pairs", "negative_dsts")
    # Iterating over ItemSetDict and indexing one by one.
    for i, item in enumerate(item_set):
        assert len(item) == 1
        assert isinstance(item, dict)
@@ -441,3 +515,12 @@ def test_ItemSetDict_iteration_node_pairs_neg_dsts():
        assert key in item
        assert torch.equal(item[key][0], value[0])
        assert torch.equal(item[key][1], value[1])
        assert item_set[i].keys() == item.keys()
        key = list(item.keys())[0]
        assert torch.equal(item_set[i][key][0], item[key][0])
        assert torch.equal(item_set[i][key][1], item[key][1])
    # Indexing with a slice.
    assert torch.equal(item_set[:]["user:like:item"][0], node_pairs)
    assert torch.equal(item_set[:]["user:like:item"][1], neg_dsts)
    assert torch.equal(item_set[:]["user:follow:user"][0], node_pairs)
    assert torch.equal(item_set[:]["user:follow:user"][1], neg_dsts)