Unverified Commit ec428409 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] add typing hint for ItemSet/ItemSetDict/MinibatchSampler (#5841)

parent 9c756a5e
"""GraphBolt Itemset."""
from typing import Dict, Iterable, Iterator, Tuple
__all__ = ["ItemSet", "ItemSetDict"]
......@@ -45,7 +47,7 @@ class ItemSet:
(tensor(4), tensor(9), tensor([18, 19]))]
"""
def __init__(self, items):
def __init__(self, items: Iterable or Tuple[Iterable]) -> None:
if isinstance(items, tuple):
assert all(
items[0].size(0) == item.size(0) for item in items
......@@ -54,7 +56,7 @@ class ItemSet:
else:
self._items = (items,)
def __iter__(self):
def __iter__(self) -> Iterator:
if len(self._items) == 1:
yield from self._items[0]
return
......@@ -119,10 +121,10 @@ class ItemSetDict:
{('user', 'follow', 'user'): (tensor(2), tensor(5), tensor([4, 5]))}]
"""
def __init__(self, itemsets: Dict[str, ItemSet]) -> None:
    """Initialize the dict-backed item set.

    Parameters
    ----------
    itemsets : Dict[str, ItemSet]
        Mapping from a string key to its ``ItemSet``.
        NOTE(review): judging by the class doctest above, keys are
        presumably type names such as ``('user', 'follow', 'user')`` /
        node-type strings — confirm against callers.
    """
    # Stored as-is; iteration order of the dict drives __iter__'s order.
    self._itemsets = itemsets
def __iter__(self) -> Iterator:
    """Iterate over every item of every contained item set.

    Yields
    ------
    dict
        A single-entry ``{key: item}`` dict per item, visiting keys in
        the insertion order of the underlying ``itemsets`` mapping and
        exhausting each key's item set before moving to the next.
    """
    for key, itemset in self._itemsets.items():
        for item in itemset:
            yield {key: item}
......@@ -2,7 +2,7 @@
from collections.abc import Mapping
from functools import partial
from typing import Optional
from typing import Iterator, Optional
from torch.utils.data import default_collate
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
......@@ -174,14 +174,14 @@ class MinibatchSampler(IterDataPipe):
batch_size: int,
drop_last: Optional[bool] = False,
shuffle: Optional[bool] = False,
):
) -> None:
super().__init__()
self._item_set = item_set
self._batch_size = batch_size
self._drop_last = drop_last
self._shuffle = shuffle
def __iter__(self):
def __iter__(self) -> Iterator:
data_pipe = IterableWrapper(self._item_set)
# Shuffle before batch.
if self._shuffle:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment