"tests/python/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "8b839a2398467258028dee51738c6a50c77bb5f9"
Unverified Commit ec428409 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] add typing hint for ItemSet/ItemSetDic/MinibatchSampler (#5841)

parent 9c756a5e
"""GraphBolt Itemset.""" """GraphBolt Itemset."""
from typing import Dict, Iterable, Iterator, Tuple
__all__ = ["ItemSet", "ItemSetDict"] __all__ = ["ItemSet", "ItemSetDict"]
...@@ -45,7 +47,7 @@ class ItemSet: ...@@ -45,7 +47,7 @@ class ItemSet:
(tensor(4), tensor(9), tensor([18, 19]))] (tensor(4), tensor(9), tensor([18, 19]))]
""" """
def __init__(self, items): def __init__(self, items: Iterable or Tuple[Iterable]) -> None:
if isinstance(items, tuple): if isinstance(items, tuple):
assert all( assert all(
items[0].size(0) == item.size(0) for item in items items[0].size(0) == item.size(0) for item in items
...@@ -54,7 +56,7 @@ class ItemSet: ...@@ -54,7 +56,7 @@ class ItemSet:
else: else:
self._items = (items,) self._items = (items,)
def __iter__(self): def __iter__(self) -> Iterator:
if len(self._items) == 1: if len(self._items) == 1:
yield from self._items[0] yield from self._items[0]
return return
...@@ -119,10 +121,10 @@ class ItemSetDict: ...@@ -119,10 +121,10 @@ class ItemSetDict:
{('user', 'follow', 'user'): (tensor(2), tensor(5), tensor([4, 5]))}] {('user', 'follow', 'user'): (tensor(2), tensor(5), tensor([4, 5]))}]
""" """
def __init__(self, itemsets): def __init__(self, itemsets: Dict[str, ItemSet]) -> None:
self._itemsets = itemsets self._itemsets = itemsets
def __iter__(self): def __iter__(self) -> Iterator:
for key, itemset in self._itemsets.items(): for key, itemset in self._itemsets.items():
for item in itemset: for item in itemset:
yield {key: item} yield {key: item}
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from collections.abc import Mapping from collections.abc import Mapping
from functools import partial from functools import partial
from typing import Optional from typing import Iterator, Optional
from torch.utils.data import default_collate from torch.utils.data import default_collate
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
...@@ -174,14 +174,14 @@ class MinibatchSampler(IterDataPipe): ...@@ -174,14 +174,14 @@ class MinibatchSampler(IterDataPipe):
batch_size: int, batch_size: int,
drop_last: Optional[bool] = False, drop_last: Optional[bool] = False,
shuffle: Optional[bool] = False, shuffle: Optional[bool] = False,
): ) -> None:
super().__init__() super().__init__()
self._item_set = item_set self._item_set = item_set
self._batch_size = batch_size self._batch_size = batch_size
self._drop_last = drop_last self._drop_last = drop_last
self._shuffle = shuffle self._shuffle = shuffle
def __iter__(self): def __iter__(self) -> Iterator:
data_pipe = IterableWrapper(self._item_set) data_pipe = IterableWrapper(self._item_set)
# Shuffle before batch. # Shuffle before batch.
if self._shuffle: if self._shuffle:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment