"tests/python/vscode:/vscode.git/clone" did not exist on "20e5e26697cdd8d3d8228d1e1071f16b43a24ee4"
Unverified Commit 557c0a86 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] enable indexing ItemSetDict in ItemSampler (#6468)

parent 09c33b9f
......@@ -120,6 +120,52 @@ class ItemShufflerAndBatcher:
(self._buffer_size + batch_size - 1) // batch_size * batch_size
)
def _collate_batch(self, buffer, indices, offsets=None):
"""Collate a batch from the buffer. For internal use only."""
if isinstance(buffer, torch.Tensor):
# For item set that's initialized with integer or single tensor,
# `buffer` is a tensor.
return buffer[indices]
elif isinstance(buffer, list) and isinstance(buffer[0], DGLGraph):
# For item set that's initialized with a list of
# DGLGraphs, `buffer` is a list of DGLGraphs.
return dgl_batch([buffer[idx] for idx in indices])
elif isinstance(buffer, tuple):
# For item set that's initialized with a tuple of items,
# `buffer` is a tuple of tensors.
return tuple(item[indices] for item in buffer)
elif isinstance(buffer, Mapping):
# For item set that's initialized with a dict of items,
# `buffer` is a dict of tensors/lists/tuples.
keys = list(buffer.keys())
key_indices = torch.searchsorted(offsets, indices, right=True) - 1
batch = {}
for j, key in enumerate(keys):
mask = (key_indices == j).nonzero().squeeze(1)
if len(mask) == 0:
continue
batch[key] = self._collate_batch(
buffer[key], indices[mask] - offsets[j]
)
return batch
raise TypeError(f"Unsupported buffer type {type(buffer).__name__}.")
def _calculate_offsets(self, buffer):
"""Calculate offsets for each item in buffer. For internal use only."""
if not isinstance(buffer, Mapping):
return None
offsets = [0]
for value in buffer.values():
if isinstance(value, torch.Tensor):
offsets.append(offsets[-1] + len(value))
elif isinstance(value, tuple):
offsets.append(offsets[-1] + len(value[0]))
else:
raise TypeError(
f"Unsupported buffer type {type(value).__name__}."
)
return torch.tensor(offsets)
def __iter__(self):
buffer = None
num_items = len(self._item_set)
......@@ -130,26 +176,12 @@ class ItemShufflerAndBatcher:
indices = torch.arange(end - start)
if self._shuffle:
np.random.shuffle(indices.numpy())
offsets = self._calculate_offsets(buffer)
for i in range(0, len(indices), self._batch_size):
if self._drop_last and i + self._batch_size > len(indices):
break
batch_indices = indices[i : i + self._batch_size]
if isinstance(self._item_set._items, int):
# For integer-initialized item set, `buffer` is a tensor.
yield buffer[batch_indices]
elif len(self._item_set._items) == 1:
if isinstance(buffer[0], DGLGraph):
# For item set that's initialized with a list of
# DGLGraphs, `buffer` is a list of DGLGraphs.
yield dgl_batch([buffer[idx] for idx in batch_indices])
else:
# For item set that's initialized with a single
# tensor, `buffer` is a tensor.
yield buffer[batch_indices]
else:
# For item set that's initialized with a tuple of items,
# `buffer` is a tuple of tensors.
yield tuple(item[batch_indices] for item in buffer)
yield self._collate_batch(buffer, batch_indices, offsets)
buffer = None
start = end
......@@ -375,8 +407,6 @@ class ItemSampler(IterDataPipe):
item_set[0]
except TypeError:
use_indexing = False
# [TODO][Rui] For now, we disable indexing for ItemSetDict.
use_indexing = (not isinstance(item_set, ItemSetDict)) and use_indexing
self._use_indexing = use_indexing
self._item_set = (
item_set if self._use_indexing else IterableWrapper(item_set)
......
......@@ -388,6 +388,52 @@ def test_append_with_other_datapipes():
assert len(data) == batch_size
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_iterable_only(batch_size, shuffle, drop_last):
class IterableOnly:
def __init__(self, start, stop):
self._start = start
self._stop = stop
def __iter__(self):
return iter(torch.arange(self._start, self._stop))
num_ids = 205
ids = {
"user": gb.ItemSet(IterableOnly(0, 99), names="seed_nodes"),
"item": gb.ItemSet(IterableOnly(99, num_ids), names="seed_nodes"),
}
chained_ids = []
for key, value in ids.items():
chained_ids += [(key, v) for v in value]
item_set = gb.ItemSetDict(ids)
item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
minibatch_ids = []
for i, minibatch in enumerate(item_sampler):
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
expected_batch_size = batch_size
else:
if not drop_last:
expected_batch_size = num_ids % batch_size
else:
assert False
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None
ids = []
for _, v in minibatch.seed_nodes.items():
ids.append(v)
ids = torch.cat(ids)
assert len(ids) == expected_batch_size
minibatch_ids.append(ids)
minibatch_ids = torch.cat(minibatch_ids)
assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment