Unverified commit 3774945d, authored by Rhett Ying and committed by GitHub
Browse files

[GraphBolt] enable slice for integer-init ItemSet (#6457)

parent b08c446d
......@@ -134,12 +134,21 @@ class ItemShufflerAndBatcher:
if self._drop_last and i + self._batch_size > len(indices):
break
batch_indices = indices[i : i + self._batch_size]
if len(self._item_set._items) == 1:
if isinstance(self._item_set._items, int):
# For integer-initialized item set, `buffer` is a tensor.
yield buffer[batch_indices]
elif len(self._item_set._items) == 1:
if isinstance(buffer[0], DGLGraph):
# For item set that's initialized with a list of
# DGLGraphs, `buffer` is a list of DGLGraphs.
yield dgl_batch([buffer[idx] for idx in batch_indices])
else:
# For item set that's initialized with a single
# tensor, `buffer` is a tensor.
yield buffer[batch_indices]
else:
# For item set that's initialized with a tuple of items,
# `buffer` is a tuple of tensors.
yield tuple(item[batch_indices] for item in buffer)
buffer = None
start = end
......
......@@ -37,7 +37,7 @@ class ItemSet:
>>> list(item_set)
[tensor(0), tensor(1), tensor(2), tensor(3), tensor(4), tensor(5),
tensor(6), tensor(7), tensor(8), tensor(9)]
>>> item_set[torch.arange(0, num)]
>>> item_set[:]
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> item_set.names
('seed_nodes',)
......@@ -151,12 +151,20 @@ class ItemSet:
f"{type(self).__name__} instance doesn't support indexing."
)
if isinstance(self._items, int):
assert isinstance(idx, (int, torch.Tensor)), (
f"Indexing of integer-initialized {type(self).__name__} "
f"instance must be int or torch.Tensor."
if isinstance(idx, slice):
start, stop, step = idx.indices(self._items)
return torch.arange(start, stop, step)
if isinstance(idx, int):
if idx < 0:
idx += self._items
if idx < 0 or idx >= self._items:
raise IndexError(
f"{type(self).__name__} index out of range."
)
return idx
raise TypeError(
f"{type(self).__name__} indices must be integer or slice."
)
# [Warning] Index range is not checked.
return idx
if len(self._items) == 1:
return self._items[0][idx]
return tuple(item[idx] for item in self._items)
......
......@@ -98,6 +98,34 @@ def test_ItemSet_Iterable_Only(batch_size, shuffle, drop_last):
assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_integer(batch_size, shuffle, drop_last):
    """Check batching of an ``ItemSet`` that is initialized with an integer.

    The item set is built from ``num_ids`` with the name ``"seed_nodes"`` and
    iterated through an ``ItemSampler``. The test verifies that:

    * every yielded object is a ``MiniBatch`` with ``seed_nodes`` set and
      ``labels`` unset;
    * every minibatch has ``batch_size`` items, except a trailing partial
      one, which may only appear when ``drop_last`` is false;
    * the concatenated IDs are in non-decreasing order exactly when
      ``shuffle`` is false.

    Parameters
    ----------
    batch_size : int
        Number of items requested per minibatch.
    shuffle : bool
        Whether the sampler shuffles the items.
    drop_last : bool
        Whether the sampler drops a trailing incomplete minibatch.
    """
    # Node IDs. 103 is not a multiple of 4, so batch_size=4 leaves a partial
    # trailing minibatch of 3 items to exercise ``drop_last``.
    num_ids = 103
    item_set = gb.ItemSet(num_ids, names="seed_nodes")
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    minibatch_ids = []
    for i, minibatch in enumerate(item_sampler):
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seed_nodes is not None
        assert minibatch.labels is None
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            assert len(minibatch.seed_nodes) == batch_size
        else:
            # A partial minibatch must never be yielded when drop_last is set.
            assert not drop_last, "partial minibatch yielded with drop_last"
            assert len(minibatch.seed_nodes) == num_ids % batch_size
        minibatch_ids.append(minibatch.seed_nodes)
    minibatch_ids = torch.cat(minibatch_ids)
    # ``torch.all`` returns a tensor; comparing a tensor to a bool with
    # ``is not`` is an identity check that is always true. Convert to a plain
    # bool so the assertion really checks ordering against ``shuffle``.
    is_sorted = bool(torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]))
    assert is_sorted is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
......
......@@ -114,16 +114,20 @@ def test_ItemSet_seed_nodes():
assert i == item.item()
assert i == item_set[i]
# Indexing with a slice.
assert torch.equal(item_set[:], torch.arange(0, 5))
# Indexing with an integer.
assert item_set[0] == 0
assert item_set[-1] == 4
# Indexing that is out of range.
with pytest.raises(IndexError, match="ItemSet index out of range."):
_ = item_set[5]
with pytest.raises(IndexError, match="ItemSet index out of range."):
_ = item_set[-10]
# Indexing with tensor.
with pytest.raises(
AssertionError,
match=(
"Indexing of integer-initialized ItemSet instance must be int or "
"torch.Tensor."
),
TypeError, match="ItemSet indices must be integer or slice."
):
_ = item_set[:]
# Indexing with an Tensor.
assert torch.equal(item_set[torch.arange(0, 5)], torch.arange(0, 5))
_ = item_set[torch.arange(3)]
def test_ItemSet_seed_nodes_labels():
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment