OpenDAS / dgl — Commit 5ada3fc9 (unverified)

[GraphBolt] rename DictItemSet as ItemSetDict (#5806)

Authored Jun 08, 2023 by Rhett Ying; committed by GitHub on Jun 08, 2023.
Parent: d88275ca
Showing 4 changed files with 26 additions and 26 deletions (+26, -26):

  python/dgl/graphbolt/itemset.py                            (+3, -3)
  python/dgl/graphbolt/minibatch_sampler.py                  (+7, -7)
  tests/python/pytorch/graphbolt/test_itemset.py             (+8, -8)
  tests/python/pytorch/graphbolt/test_minibatch_sampler.py   (+8, -8)
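The change is a pure rename: behaviour is unchanged and downstream code updates by swapping the identifier (as the itemset.py hunk below shows, the __all__ entry is replaced rather than aliased). A minimal before/after sketch of hypothetical user code, assuming the dgl.graphbolt package is importable:

import torch
import dgl.graphbolt as gb

# Before this commit (old public name):
# item_set = gb.DictItemSet({"user": gb.ItemSet(torch.arange(0, 5))})

# After this commit (new public name):
item_set = gb.ItemSetDict({"user": gb.ItemSet(torch.arange(0, 5))})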
python/dgl/graphbolt/itemset.py

 """GraphBolt Itemset."""

-__all__ = ["ItemSet", "DictItemSet"]
+__all__ = ["ItemSet", "ItemSetDict"]


 class ItemSet:
...
@@ -36,8 +36,8 @@ class ItemSet:
         raise NotImplementedError


-class DictItemSet:
-    r"""Itemset wrapping multiple itemsets with keys.
+class ItemSetDict:
+    r"""An iterable ItemSetDict.

     Each item is retrieved by iterating over each itemset and returned with
     corresponding key as a dict.
...
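Per the updated docstring, ItemSetDict iterates through its wrapped itemsets in order and yields each element as a single-entry dict keyed by the itemset it came from. A minimal sketch of that iteration contract, adapted from the test assertions in this commit (the dgl.graphbolt import path is assumed):

import torch
import dgl.graphbolt as gb

ids = {
    "user": gb.ItemSet(torch.arange(0, 5)),
    "item": gb.ItemSet(torch.arange(0, 6)),
}
item_set = gb.ItemSetDict(ids)
for item in item_set:
    # Each yielded item is a one-entry dict, e.g. {"user": tensor(0)}.
    assert isinstance(item, dict) and len(item) == 1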
python/dgl/graphbolt/minibatch_sampler.py

...
@@ -9,7 +9,7 @@ from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
 from ..batch import batch as dgl_batch
 from ..heterograph import DGLGraph
-from .itemset import DictItemSet, ItemSet
+from .itemset import ItemSet, ItemSetDict

 __all__ = ["MinibatchSampler"]
...
@@ -28,7 +28,7 @@ class MinibatchSampler(IterDataPipe):
     Parameters
     ----------
-    item_set : ItemSet or DictItemSet
+    item_set : ItemSet or ItemSetDict
         Data to be sampled for mini-batches.
     batch_size : int
         The size of each batch.
...
@@ -106,7 +106,7 @@ class MinibatchSampler(IterDataPipe):
     ...     "user": gb.ItemSet(torch.arange(0, 5)),
     ...     "item": gb.ItemSet(torch.arange(0, 6)),
     ... }
-    >>> item_set = gb.DictItemSet(ids)
+    >>> item_set = gb.ItemSetDict(ids)
     >>> minibatch_sampler = gb.MinibatchSampler(item_set, 4)
     >>> list(minibatch_sampler)
     [{'user': tensor([0, 1, 2, 3])},
...
@@ -116,7 +116,7 @@ class MinibatchSampler(IterDataPipe):
     8. Heterogeneous node pairs.

     >>> node_pairs_like = (torch.arange(0, 5), torch.arange(0, 5))
     >>> node_pairs_follow = (torch.arange(0, 6), torch.arange(6, 12))
-    >>> item_set = gb.DictItemSet({
+    >>> item_set = gb.ItemSetDict({
     ...     ("user", "like", "item"): gb.ItemSet(node_pairs_like),
     ...     ("user", "follow", "user"): gb.ItemSet(node_pairs_follow),
     ... })
...
@@ -132,7 +132,7 @@ class MinibatchSampler(IterDataPipe):
     ...     torch.arange(0, 5), torch.arange(0, 5), torch.arange(0, 5))
     >>> follow = (
     ...     torch.arange(0, 6), torch.arange(6, 12), torch.arange(0, 6))
-    >>> item_set = gb.DictItemSet({
+    >>> item_set = gb.ItemSetDict({
     ...     ("user", "like", "item"): gb.ItemSet(like),
     ...     ("user", "follow", "user"): gb.ItemSet(follow),
     ... })
...
@@ -153,7 +153,7 @@ class MinibatchSampler(IterDataPipe):
     >>> follow = (
     ...     torch.arange(0, 6), torch.arange(6, 12),
     ...     torch.arange(12, 24).reshape(-1, 2))
-    >>> item_set = gb.DictItemSet({
+    >>> item_set = gb.ItemSetDict({
     ...     ("user", "like", "item"): gb.ItemSet(like),
     ...     ("user", "follow", "user"): gb.ItemSet(follow),
     ... })
...
@@ -170,7 +170,7 @@ class MinibatchSampler(IterDataPipe):
     def __init__(
         self,
-        item_set: ItemSet or DictItemSet,
+        item_set: ItemSet or ItemSetDict,
         batch_size: int,
         drop_last: Optional[bool] = False,
         shuffle: Optional[bool] = False,
...
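The constructor signature in the last hunk also takes batch_size, shuffle, and drop_last; the tests below exercise all three. A self-contained sketch combining ItemSetDict with MinibatchSampler, based on the heterogeneous node-pair example above (import path assumed):

import torch
import dgl.graphbolt as gb

item_set = gb.ItemSetDict({
    ("user", "like", "item"): gb.ItemSet(
        (torch.arange(0, 5), torch.arange(0, 5))),
    ("user", "follow", "user"): gb.ItemSet(
        (torch.arange(0, 6), torch.arange(6, 12))),
})
# Batches of up to 4 node pairs; the trailing short batch is kept because
# drop_last=False, and iteration order is preserved because shuffle=False.
sampler = gb.MinibatchSampler(
    item_set, batch_size=4, shuffle=False, drop_last=False
)
for batch in sampler:
    print(batch)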
tests/python/pytorch/graphbolt/test_itemset.py

...
@@ -51,7 +51,7 @@ def test_ItemSet_head_tail_neg_tails():
         assert_close(neg_tails[i], negs)


-def test_DictItemSet_node_edge_ids():
+def test_ItemSetDict_node_edge_ids():
     # Node or edge IDs
     ids = {
         ("user", "like", "item"): ItemSet(torch.arange(0, 5)),
...
@@ -60,7 +60,7 @@ def test_DictItemSet_node_edge_ids():
     chained_ids = []
     for key, value in ids.items():
         chained_ids += [(key, v) for v in value]
-    item_set = DictItemSet(ids)
+    item_set = ItemSetDict(ids)
     for i, item in enumerate(item_set):
         assert len(item) == 1
         assert isinstance(item, dict)
...
@@ -68,7 +68,7 @@ def test_DictItemSet_node_edge_ids():
         assert item[chained_ids[i][0]] == chained_ids[i][1]


-def test_DictItemSet_node_pairs():
+def test_ItemSetDict_node_pairs():
     # Node pairs.
     node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
     node_pairs_dict = {
...
@@ -78,7 +78,7 @@ def test_DictItemSet_node_pairs():
     expected_data = []
     for key, value in node_pairs_dict.items():
         expected_data += [(key, v) for v in value]
-    item_set = DictItemSet(node_pairs_dict)
+    item_set = ItemSetDict(node_pairs_dict)
     for i, item in enumerate(item_set):
         assert len(item) == 1
         assert isinstance(item, dict)
...
@@ -86,7 +86,7 @@ def test_DictItemSet_node_pairs():
         assert item[expected_data[i][0]] == expected_data[i][1]


-def test_DictItemSet_node_pairs_labels():
+def test_ItemSetDict_node_pairs_labels():
     # Node pairs and labels
     node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
     labels = torch.randint(0, 3, (5,))
...
@@ -101,7 +101,7 @@ def test_DictItemSet_node_pairs_labels():
     expected_data = []
     for key, value in node_pairs_dict.items():
         expected_data += [(key, v) for v in value]
-    item_set = DictItemSet(node_pairs_dict)
+    item_set = ItemSetDict(node_pairs_dict)
     for i, item in enumerate(item_set):
         assert len(item) == 1
         assert isinstance(item, dict)
...
@@ -109,7 +109,7 @@ def test_DictItemSet_node_pairs_labels():
         assert item[expected_data[i][0]] == expected_data[i][1]


-def test_DictItemSet_head_tail_neg_tails():
+def test_ItemSetDict_head_tail_neg_tails():
     # Head, tail and negative tails.
     heads = torch.arange(0, 5)
     tails = torch.arange(5, 10)
...
@@ -122,7 +122,7 @@ def test_DictItemSet_head_tail_neg_tails():
     expected_data = []
     for key, value in data_dict.items():
         expected_data += [(key, v) for v in value]
-    item_set = DictItemSet(data_dict)
+    item_set = ItemSetDict(data_dict)
     for i, item in enumerate(item_set):
         assert len(item) == 1
         assert isinstance(item, dict)
...
tests/python/pytorch/graphbolt/test_minibatch_sampler.py

...
@@ -215,7 +215,7 @@ def test_append_with_other_datapipes():
 @pytest.mark.parametrize("batch_size", [1, 4])
 @pytest.mark.parametrize("shuffle", [True, False])
 @pytest.mark.parametrize("drop_last", [True, False])
-def test_DictItemSet_node_ids(batch_size, shuffle, drop_last):
+def test_ItemSetDict_node_ids(batch_size, shuffle, drop_last):
     # Node IDs.
     num_ids = 205
     ids = {
...
@@ -225,7 +225,7 @@ def test_DictItemSet_node_ids(batch_size, shuffle, drop_last):
     chained_ids = []
     for key, value in ids.items():
         chained_ids += [(key, v) for v in value]
-    item_set = gb.DictItemSet(ids)
+    item_set = gb.ItemSetDict(ids)
     minibatch_sampler = gb.MinibatchSampler(
         item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
     )
...
@@ -253,7 +253,7 @@ def test_DictItemSet_node_ids(batch_size, shuffle, drop_last):
 @pytest.mark.parametrize("batch_size", [1, 4])
 @pytest.mark.parametrize("shuffle", [True, False])
 @pytest.mark.parametrize("drop_last", [True, False])
-def test_DictItemSet_node_pairs(batch_size, shuffle, drop_last):
+def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last):
     # Node pairs.
     num_ids = 103
     total_ids = 2 * num_ids
...
@@ -269,7 +269,7 @@ def test_DictItemSet_node_pairs(batch_size, shuffle, drop_last):
         ("user", "like", "item"): gb.ItemSet(node_pairs_0),
         ("user", "follow", "user"): gb.ItemSet(node_pairs_1),
     }
-    item_set = gb.DictItemSet(node_pairs_dict)
+    item_set = gb.ItemSetDict(node_pairs_dict)
     minibatch_sampler = gb.MinibatchSampler(
         item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
     )
...
@@ -305,7 +305,7 @@ def test_DictItemSet_node_pairs(batch_size, shuffle, drop_last):
 @pytest.mark.parametrize("batch_size", [1, 4])
 @pytest.mark.parametrize("shuffle", [True, False])
 @pytest.mark.parametrize("drop_last", [True, False])
-def test_DictItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
+def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
     # Node pairs and labels
     num_ids = 103
     total_ids = 2 * num_ids
...
@@ -326,7 +326,7 @@ def test_DictItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
             (node_pairs_1[0], node_pairs_1[1], labels + num_ids * 2)
         ),
     }
-    item_set = gb.DictItemSet(node_pairs_dict)
+    item_set = gb.ItemSetDict(node_pairs_dict)
     minibatch_sampler = gb.MinibatchSampler(
         item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
     )
...
@@ -371,7 +371,7 @@ def test_DictItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
 @pytest.mark.parametrize("batch_size", [1, 4])
 @pytest.mark.parametrize("shuffle", [True, False])
 @pytest.mark.parametrize("drop_last", [True, False])
-def test_DictItemSet_head_tail_neg_tails(batch_size, shuffle, drop_last):
+def test_ItemSetDict_head_tail_neg_tails(batch_size, shuffle, drop_last):
     # Head, tail and negative tails.
     num_ids = 103
     total_ids = 2 * num_ids
...
@@ -383,7 +383,7 @@ def test_DictItemSet_head_tail_neg_tails(batch_size, shuffle, drop_last):
         ("user", "like", "item"): gb.ItemSet((heads, tails, neg_tails)),
         ("user", "follow", "user"): gb.ItemSet((heads, tails, neg_tails)),
     }
-    item_set = gb.DictItemSet(data_dict)
+    item_set = gb.ItemSetDict(data_dict)
     minibatch_sampler = gb.MinibatchSampler(
         item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
     )
...