Unverified commit 3df6e301, authored by Mingbang Wang, committed by GitHub
Browse files

[GraphBolt] Update `__len__()` and `__getitem__()` of `ItemSet` (#7253)

parent 8e3d8101
"""GraphBolt Itemset."""
import textwrap
from typing import Dict, Iterable, Iterator, Sized, Tuple, Union
from typing import Dict, Iterable, Iterator, Tuple, Union
import torch
......@@ -119,20 +119,35 @@ class ItemSet:
items: Union[int, torch.Tensor, Iterable, Tuple[Iterable]],
names: Union[str, Tuple[str]] = None,
) -> None:
if isinstance(items, tuple) or is_scalar(items):
if is_scalar(items):
self._length = int(items)
self._items = items
self._num_items = 1
elif isinstance(items, tuple):
try:
self._length = len(items[0])
except TypeError:
self._length = None
if self._length is not None:
if any(self._length != len(item) for item in items):
raise ValueError("Size mismatch between items.")
self._items = items
self._num_items = len(items)
else:
try:
self._length = len(items)
except TypeError:
self._length = None
self._items = (items,)
self._num_items = 1
if names is not None:
num_items = (
len(self._items) if isinstance(self._items, tuple) else 1
)
if isinstance(names, tuple):
self._names = names
else:
self._names = (names,)
assert num_items == len(self._names), (
f"Number of items ({num_items}) and "
assert self._num_items == len(self._names), (
f"Number of items ({self._num_items}) and "
f"names ({len(self._names)}) must match."
)
else:
......@@ -144,12 +159,11 @@ class ItemSet:
yield from torch.arange(self._items, dtype=dtype)
return
if len(self._items) == 1:
if self._num_items == 1:
yield from self._items[0]
return
if isinstance(self._items[0], Sized):
items_len = len(self._items[0])
if self._length is not None:
# Use for-loop to iterate over the items. It can avoid a long
# waiting time when the items are torch tensors. Since torch
# tensors need to call self.unbind(0) to slice themselves.
......@@ -157,7 +171,7 @@ class ItemSet:
# wait times during the loading phase, and the impact on overall
# performance during the training/testing stage is minimal.
# For more details, see https://github.com/dmlc/dgl/pull/6293.
for i in range(items_len):
for i in range(self._length):
yield tuple(item[i] for item in self._items)
else:
# If the items are not Sized, we use zip to iterate over them.
......@@ -165,31 +179,20 @@ class ItemSet:
for item in zip_items:
yield tuple(item)
def __len__(self) -> int:
if is_scalar(self._items):
return int(self._items)
if isinstance(self._items[0], Sized):
return len(self._items[0])
raise TypeError(
f"{type(self).__name__} instance doesn't have valid length."
)
def __getitem__(self, idx: Union[int, slice, Iterable]) -> Tuple:
try:
len(self)
except TypeError:
if self._length is None:
raise TypeError(
f"{type(self).__name__} instance doesn't support indexing."
)
if is_scalar(self._items):
if isinstance(idx, slice):
start, stop, step = idx.indices(int(self._items))
start, stop, step = idx.indices(self._length)
dtype = getattr(self._items, "dtype", torch.int64)
return torch.arange(start, stop, step, dtype=dtype)
if isinstance(idx, int):
if idx < 0:
idx += self._items
if idx < 0 or idx >= self._items:
idx += self._length
if idx < 0 or idx >= self._length:
raise IndexError(
f"{type(self).__name__} index out of range."
)
......@@ -201,7 +204,7 @@ class ItemSet:
raise TypeError(
f"{type(self).__name__} indices must be integer or slice."
)
if len(self._items) == 1:
if self._num_items == 1:
return self._items[0][idx]
return tuple(item[idx] for item in self._items)
......@@ -210,6 +213,18 @@ class ItemSet:
"""Return the names of the items."""
return self._names
@property
def num_items(self) -> int:
"""Return the number of the items."""
return self._num_items
def __len__(self):
if self._length is None:
raise TypeError(
f"{type(self).__name__} instance doesn't have valid length."
)
return self._length
def __repr__(self) -> str:
ret = (
f"{self.__class__.__name__}(\n"
......@@ -364,8 +379,10 @@ class ItemSetDict:
if stop <= self._offsets[offset_idx]:
break
return data
raise TypeError(f"{type(self).__name__} indices must be int or slice.")
else:
raise TypeError(
f"{type(self).__name__} indices must be int or slice."
)
@property
def names(self) -> Tuple[str]:
......
......@@ -529,7 +529,7 @@ def test_SubgraphSampler_Random_Hetero_Graph_seed_ndoes(sampler_type, replace):
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_without_dedpulication_Homo_seed_nodes(sampler_type):
def test_SubgraphSampler_without_deduplication_Homo_seed_nodes(sampler_type):
_check_sampler_type(sampler_type)
graph = dgl.graph(
([5, 0, 1, 5, 6, 7, 2, 2, 4], [0, 1, 2, 2, 2, 2, 3, 4, 4])
......@@ -643,7 +643,7 @@ def _assert_homo_values(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_without_dedpulication_Hetero_seed_nodes(sampler_type):
def test_SubgraphSampler_without_deduplication_Hetero_seed_nodes(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx())
items = torch.arange(2)
......@@ -1409,7 +1409,7 @@ def test_SubgraphSampler_Random_Hetero_Graph(sampler_type, replace):
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_without_dedpulication_Homo_Node(sampler_type):
def test_SubgraphSampler_without_deduplication_Homo_Node(sampler_type):
_check_sampler_type(sampler_type)
graph = dgl.graph(
([5, 0, 1, 5, 6, 7, 2, 2, 4], [0, 1, 2, 2, 2, 2, 3, 4, 4])
......@@ -1473,7 +1473,7 @@ def test_SubgraphSampler_without_dedpulication_Homo_Node(sampler_type):
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_without_dedpulication_Hetero_Node(sampler_type):
def test_SubgraphSampler_without_deduplication_Hetero_Node(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx())
items = torch.arange(2)
......@@ -1829,7 +1829,7 @@ def test_SubgraphSampler_Hetero_multifanout_per_layer(sampler_type):
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_without_dedpulication_Homo_Link(sampler_type):
def test_SubgraphSampler_without_deduplication_Homo_Link(sampler_type):
_check_sampler_type(sampler_type)
graph = dgl.graph(
([5, 0, 1, 5, 6, 7, 2, 2, 4], [0, 1, 2, 2, 2, 2, 3, 4, 4])
......@@ -1845,7 +1845,7 @@ def test_SubgraphSampler_without_dedpulication_Homo_Link(sampler_type):
graph.edge_attributes = {
"timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
}
items = (items, torch.randint(1, 10, (3,)))
items = (items, torch.randint(1, 10, (2,)))
names = (names, "timestamp")
itemset = gb.ItemSet(items, names=names)
......@@ -1891,7 +1891,7 @@ def test_SubgraphSampler_without_dedpulication_Homo_Link(sampler_type):
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_without_dedpulication_Hetero_Link(sampler_type):
def test_SubgraphSampler_without_deduplication_Hetero_Link(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx())
items = torch.arange(2).view(1, 2)
......@@ -1903,7 +1903,7 @@ def test_SubgraphSampler_without_dedpulication_Hetero_Link(sampler_type):
graph.edge_attributes = {
"timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
}
items = (items, torch.randint(1, 10, (2,)))
items = (items, torch.randint(1, 10, (1,)))
names = (names, "timestamp")
itemset = gb.ItemSetDict({"n1:e1:n2": gb.ItemSet(items, names=names)})
item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment