Unverified Commit 3df6e301 authored by Mingbang Wang's avatar Mingbang Wang Committed by GitHub
Browse files

[GraphBolt] Update `__len__()` and `__getitem__()` of `ItemSet` (#7253)

parent 8e3d8101
"""GraphBolt Itemset.""" """GraphBolt Itemset."""
import textwrap import textwrap
from typing import Dict, Iterable, Iterator, Sized, Tuple, Union from typing import Dict, Iterable, Iterator, Tuple, Union
import torch import torch
...@@ -119,20 +119,35 @@ class ItemSet: ...@@ -119,20 +119,35 @@ class ItemSet:
items: Union[int, torch.Tensor, Iterable, Tuple[Iterable]], items: Union[int, torch.Tensor, Iterable, Tuple[Iterable]],
names: Union[str, Tuple[str]] = None, names: Union[str, Tuple[str]] = None,
) -> None: ) -> None:
if isinstance(items, tuple) or is_scalar(items): if is_scalar(items):
self._length = int(items)
self._items = items self._items = items
self._num_items = 1
elif isinstance(items, tuple):
try:
self._length = len(items[0])
except TypeError:
self._length = None
if self._length is not None:
if any(self._length != len(item) for item in items):
raise ValueError("Size mismatch between items.")
self._items = items
self._num_items = len(items)
else: else:
try:
self._length = len(items)
except TypeError:
self._length = None
self._items = (items,) self._items = (items,)
self._num_items = 1
if names is not None: if names is not None:
num_items = (
len(self._items) if isinstance(self._items, tuple) else 1
)
if isinstance(names, tuple): if isinstance(names, tuple):
self._names = names self._names = names
else: else:
self._names = (names,) self._names = (names,)
assert num_items == len(self._names), ( assert self._num_items == len(self._names), (
f"Number of items ({num_items}) and " f"Number of items ({self._num_items}) and "
f"names ({len(self._names)}) must match." f"names ({len(self._names)}) must match."
) )
else: else:
...@@ -144,12 +159,11 @@ class ItemSet: ...@@ -144,12 +159,11 @@ class ItemSet:
yield from torch.arange(self._items, dtype=dtype) yield from torch.arange(self._items, dtype=dtype)
return return
if len(self._items) == 1: if self._num_items == 1:
yield from self._items[0] yield from self._items[0]
return return
if isinstance(self._items[0], Sized): if self._length is not None:
items_len = len(self._items[0])
# Use for-loop to iterate over the items. It can avoid a long # Use for-loop to iterate over the items. It can avoid a long
# waiting time when the items are torch tensors. Since torch # waiting time when the items are torch tensors. Since torch
# tensors need to call self.unbind(0) to slice themselves. # tensors need to call self.unbind(0) to slice themselves.
...@@ -157,7 +171,7 @@ class ItemSet: ...@@ -157,7 +171,7 @@ class ItemSet:
# wait times during the loading phase, and the impact on overall # wait times during the loading phase, and the impact on overall
# performance during the training/testing stage is minimal. # performance during the training/testing stage is minimal.
# For more details, see https://github.com/dmlc/dgl/pull/6293. # For more details, see https://github.com/dmlc/dgl/pull/6293.
for i in range(items_len): for i in range(self._length):
yield tuple(item[i] for item in self._items) yield tuple(item[i] for item in self._items)
else: else:
# If the items are not Sized, we use zip to iterate over them. # If the items are not Sized, we use zip to iterate over them.
...@@ -165,31 +179,20 @@ class ItemSet: ...@@ -165,31 +179,20 @@ class ItemSet:
for item in zip_items: for item in zip_items:
yield tuple(item) yield tuple(item)
def __len__(self) -> int:
if is_scalar(self._items):
return int(self._items)
if isinstance(self._items[0], Sized):
return len(self._items[0])
raise TypeError(
f"{type(self).__name__} instance doesn't have valid length."
)
def __getitem__(self, idx: Union[int, slice, Iterable]) -> Tuple: def __getitem__(self, idx: Union[int, slice, Iterable]) -> Tuple:
try: if self._length is None:
len(self)
except TypeError:
raise TypeError( raise TypeError(
f"{type(self).__name__} instance doesn't support indexing." f"{type(self).__name__} instance doesn't support indexing."
) )
if is_scalar(self._items): if is_scalar(self._items):
if isinstance(idx, slice): if isinstance(idx, slice):
start, stop, step = idx.indices(int(self._items)) start, stop, step = idx.indices(self._length)
dtype = getattr(self._items, "dtype", torch.int64) dtype = getattr(self._items, "dtype", torch.int64)
return torch.arange(start, stop, step, dtype=dtype) return torch.arange(start, stop, step, dtype=dtype)
if isinstance(idx, int): if isinstance(idx, int):
if idx < 0: if idx < 0:
idx += self._items idx += self._length
if idx < 0 or idx >= self._items: if idx < 0 or idx >= self._length:
raise IndexError( raise IndexError(
f"{type(self).__name__} index out of range." f"{type(self).__name__} index out of range."
) )
...@@ -201,7 +204,7 @@ class ItemSet: ...@@ -201,7 +204,7 @@ class ItemSet:
raise TypeError( raise TypeError(
f"{type(self).__name__} indices must be integer or slice." f"{type(self).__name__} indices must be integer or slice."
) )
if len(self._items) == 1: if self._num_items == 1:
return self._items[0][idx] return self._items[0][idx]
return tuple(item[idx] for item in self._items) return tuple(item[idx] for item in self._items)
...@@ -210,6 +213,18 @@ class ItemSet: ...@@ -210,6 +213,18 @@ class ItemSet:
"""Return the names of the items.""" """Return the names of the items."""
return self._names return self._names
@property
def num_items(self) -> int:
"""Return the number of the items."""
return self._num_items
def __len__(self):
if self._length is None:
raise TypeError(
f"{type(self).__name__} instance doesn't have valid length."
)
return self._length
def __repr__(self) -> str: def __repr__(self) -> str:
ret = ( ret = (
f"{self.__class__.__name__}(\n" f"{self.__class__.__name__}(\n"
...@@ -364,8 +379,10 @@ class ItemSetDict: ...@@ -364,8 +379,10 @@ class ItemSetDict:
if stop <= self._offsets[offset_idx]: if stop <= self._offsets[offset_idx]:
break break
return data return data
else:
raise TypeError(f"{type(self).__name__} indices must be int or slice.") raise TypeError(
f"{type(self).__name__} indices must be int or slice."
)
@property @property
def names(self) -> Tuple[str]: def names(self) -> Tuple[str]:
......
...@@ -529,7 +529,7 @@ def test_SubgraphSampler_Random_Hetero_Graph_seed_ndoes(sampler_type, replace): ...@@ -529,7 +529,7 @@ def test_SubgraphSampler_Random_Hetero_Graph_seed_ndoes(sampler_type, replace):
"sampler_type", "sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal], [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
) )
def test_SubgraphSampler_without_dedpulication_Homo_seed_nodes(sampler_type): def test_SubgraphSampler_without_deduplication_Homo_seed_nodes(sampler_type):
_check_sampler_type(sampler_type) _check_sampler_type(sampler_type)
graph = dgl.graph( graph = dgl.graph(
([5, 0, 1, 5, 6, 7, 2, 2, 4], [0, 1, 2, 2, 2, 2, 3, 4, 4]) ([5, 0, 1, 5, 6, 7, 2, 2, 4], [0, 1, 2, 2, 2, 2, 3, 4, 4])
...@@ -643,7 +643,7 @@ def _assert_homo_values( ...@@ -643,7 +643,7 @@ def _assert_homo_values(
"sampler_type", "sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal], [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
) )
def test_SubgraphSampler_without_dedpulication_Hetero_seed_nodes(sampler_type): def test_SubgraphSampler_without_deduplication_Hetero_seed_nodes(sampler_type):
_check_sampler_type(sampler_type) _check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx()) graph = get_hetero_graph().to(F.ctx())
items = torch.arange(2) items = torch.arange(2)
...@@ -1409,7 +1409,7 @@ def test_SubgraphSampler_Random_Hetero_Graph(sampler_type, replace): ...@@ -1409,7 +1409,7 @@ def test_SubgraphSampler_Random_Hetero_Graph(sampler_type, replace):
"sampler_type", "sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal], [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
) )
def test_SubgraphSampler_without_dedpulication_Homo_Node(sampler_type): def test_SubgraphSampler_without_deduplication_Homo_Node(sampler_type):
_check_sampler_type(sampler_type) _check_sampler_type(sampler_type)
graph = dgl.graph( graph = dgl.graph(
([5, 0, 1, 5, 6, 7, 2, 2, 4], [0, 1, 2, 2, 2, 2, 3, 4, 4]) ([5, 0, 1, 5, 6, 7, 2, 2, 4], [0, 1, 2, 2, 2, 2, 3, 4, 4])
...@@ -1473,7 +1473,7 @@ def test_SubgraphSampler_without_dedpulication_Homo_Node(sampler_type): ...@@ -1473,7 +1473,7 @@ def test_SubgraphSampler_without_dedpulication_Homo_Node(sampler_type):
"sampler_type", "sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal], [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
) )
def test_SubgraphSampler_without_dedpulication_Hetero_Node(sampler_type): def test_SubgraphSampler_without_deduplication_Hetero_Node(sampler_type):
_check_sampler_type(sampler_type) _check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx()) graph = get_hetero_graph().to(F.ctx())
items = torch.arange(2) items = torch.arange(2)
...@@ -1829,7 +1829,7 @@ def test_SubgraphSampler_Hetero_multifanout_per_layer(sampler_type): ...@@ -1829,7 +1829,7 @@ def test_SubgraphSampler_Hetero_multifanout_per_layer(sampler_type):
"sampler_type", "sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal], [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
) )
def test_SubgraphSampler_without_dedpulication_Homo_Link(sampler_type): def test_SubgraphSampler_without_deduplication_Homo_Link(sampler_type):
_check_sampler_type(sampler_type) _check_sampler_type(sampler_type)
graph = dgl.graph( graph = dgl.graph(
([5, 0, 1, 5, 6, 7, 2, 2, 4], [0, 1, 2, 2, 2, 2, 3, 4, 4]) ([5, 0, 1, 5, 6, 7, 2, 2, 4], [0, 1, 2, 2, 2, 2, 3, 4, 4])
...@@ -1845,7 +1845,7 @@ def test_SubgraphSampler_without_dedpulication_Homo_Link(sampler_type): ...@@ -1845,7 +1845,7 @@ def test_SubgraphSampler_without_dedpulication_Homo_Link(sampler_type):
graph.edge_attributes = { graph.edge_attributes = {
"timestamp": torch.zeros(graph.indices.numel()).to(F.ctx()) "timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
} }
items = (items, torch.randint(1, 10, (3,))) items = (items, torch.randint(1, 10, (2,)))
names = (names, "timestamp") names = (names, "timestamp")
itemset = gb.ItemSet(items, names=names) itemset = gb.ItemSet(items, names=names)
...@@ -1891,7 +1891,7 @@ def test_SubgraphSampler_without_dedpulication_Homo_Link(sampler_type): ...@@ -1891,7 +1891,7 @@ def test_SubgraphSampler_without_dedpulication_Homo_Link(sampler_type):
"sampler_type", "sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal], [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
) )
def test_SubgraphSampler_without_dedpulication_Hetero_Link(sampler_type): def test_SubgraphSampler_without_deduplication_Hetero_Link(sampler_type):
_check_sampler_type(sampler_type) _check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx()) graph = get_hetero_graph().to(F.ctx())
items = torch.arange(2).view(1, 2) items = torch.arange(2).view(1, 2)
...@@ -1903,7 +1903,7 @@ def test_SubgraphSampler_without_dedpulication_Hetero_Link(sampler_type): ...@@ -1903,7 +1903,7 @@ def test_SubgraphSampler_without_dedpulication_Hetero_Link(sampler_type):
graph.edge_attributes = { graph.edge_attributes = {
"timestamp": torch.zeros(graph.indices.numel()).to(F.ctx()) "timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
} }
items = (items, torch.randint(1, 10, (2,))) items = (items, torch.randint(1, 10, (1,)))
names = (names, "timestamp") names = (names, "timestamp")
itemset = gb.ItemSetDict({"n1:e1:n2": gb.ItemSet(items, names=names)}) itemset = gb.ItemSetDict({"n1:e1:n2": gb.ItemSet(items, names=names)})
item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx()) item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment