Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
ec428409
Unverified
Commit
ec428409
authored
Jun 09, 2023
by
Rhett Ying
Committed by
GitHub
Jun 09, 2023
Browse files
[GraphBolt] add typing hint for ItemSet/ItemSetDic/MinibatchSampler (#5841)
parent
9c756a5e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
7 deletions
+9
-7
python/dgl/graphbolt/itemset.py
python/dgl/graphbolt/itemset.py
+6
-4
python/dgl/graphbolt/minibatch_sampler.py
python/dgl/graphbolt/minibatch_sampler.py
+3
-3
No files found.
python/dgl/graphbolt/itemset.py
View file @
ec428409
"""GraphBolt Itemset."""
from
typing
import
Dict
,
Iterable
,
Iterator
,
Tuple
__all__
=
[
"ItemSet"
,
"ItemSetDict"
]
...
...
@@ -45,7 +47,7 @@ class ItemSet:
(tensor(4), tensor(9), tensor([18, 19]))]
"""
def
__init__
(
self
,
items
)
:
def
__init__
(
self
,
items
:
Iterable
or
Tuple
[
Iterable
])
->
None
:
if
isinstance
(
items
,
tuple
):
assert
all
(
items
[
0
].
size
(
0
)
==
item
.
size
(
0
)
for
item
in
items
...
...
@@ -54,7 +56,7 @@ class ItemSet:
else
:
self
.
_items
=
(
items
,)
def
__iter__
(
self
):
def
__iter__
(
self
)
->
Iterator
:
if
len
(
self
.
_items
)
==
1
:
yield
from
self
.
_items
[
0
]
return
...
...
@@ -119,10 +121,10 @@ class ItemSetDict:
{('user', 'follow', 'user'): (tensor(2), tensor(5), tensor([4, 5]))}]
"""
def
__init__
(
self
,
itemsets
)
:
def
__init__
(
self
,
itemsets
:
Dict
[
str
,
ItemSet
])
->
None
:
self
.
_itemsets
=
itemsets
def
__iter__
(
self
):
def
__iter__
(
self
)
->
Iterator
:
for
key
,
itemset
in
self
.
_itemsets
.
items
():
for
item
in
itemset
:
yield
{
key
:
item
}
python/dgl/graphbolt/minibatch_sampler.py
View file @
ec428409
...
...
@@ -2,7 +2,7 @@
from
collections.abc
import
Mapping
from
functools
import
partial
from
typing
import
Optional
from
typing
import
Iterator
,
Optional
from
torch.utils.data
import
default_collate
from
torchdata.datapipes.iter
import
IterableWrapper
,
IterDataPipe
...
...
@@ -174,14 +174,14 @@ class MinibatchSampler(IterDataPipe):
batch_size
:
int
,
drop_last
:
Optional
[
bool
]
=
False
,
shuffle
:
Optional
[
bool
]
=
False
,
):
)
->
None
:
super
().
__init__
()
self
.
_item_set
=
item_set
self
.
_batch_size
=
batch_size
self
.
_drop_last
=
drop_last
self
.
_shuffle
=
shuffle
def
__iter__
(
self
):
def
__iter__
(
self
)
->
Iterator
:
data_pipe
=
IterableWrapper
(
self
.
_item_set
)
# Shuffle before batch.
if
self
.
_shuffle
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment