Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
37d70e54
Unverified
Commit
37d70e54
authored
Sep 06, 2023
by
Rhett Ying
Committed by
GitHub
Sep 06, 2023
Browse files
[GraphBolt] split node_pairs to tuple of (src, dst) (#6291)
parent
50b05723
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
58 additions
and
49 deletions
+58
-49
python/dgl/graphbolt/item_sampler.py
python/dgl/graphbolt/item_sampler.py
+37
-33
python/dgl/graphbolt/itemset.py
python/dgl/graphbolt/itemset.py
+3
-2
tests/python/pytorch/graphbolt/test_item_sampler.py
tests/python/pytorch/graphbolt/test_item_sampler.py
+18
-14
No files found.
python/dgl/graphbolt/item_sampler.py
View file @
37d70e54
...
...
@@ -67,6 +67,13 @@ def minibatcher_default(batch, names):
"`MiniBatch`. You probably need to provide a customized "
"`MiniBatcher`."
)
if
name
==
"node_pairs"
:
# `node_pairs` is passed as a tensor in shape of `(N, 2)` and
# should be converted to a tuple of `(src, dst)`.
if
isinstance
(
item
,
Mapping
):
item
=
{
key
:
(
item
[
key
][:,
0
],
item
[
key
][:,
1
])
for
key
in
item
}
else
:
item
=
(
item
[:,
0
],
item
[:,
1
])
setattr
(
minibatch
,
name
,
item
)
return
minibatch
...
...
@@ -103,10 +110,10 @@ class ItemSampler(IterDataPipe):
>>> from dgl import graphbolt as gb
>>> item_set = gb.ItemSet(torch.arange(0, 10), names="seed_nodes")
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False
... item_set, batch_size=4, shuffle=False, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=tensor([9, 0, 7, 2]), node_pairs=None, labels=None,
MiniBatch(seed_nodes=tensor([0, 1, 2, 3]), node_pairs=None, labels=None,
negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
input_nodes=None, node_features=None, edge_features=None,
compacted_node_pairs=None, compacted_negative_srcs=None,
...
...
@@ -116,30 +123,28 @@ class ItemSampler(IterDataPipe):
>>> item_set = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2),
... names="node_pairs")
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False
... item_set, batch_size=4, shuffle=False, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs=tensor([[16, 17],
[ 4, 5],
[ 6, 7],
[10, 11]]), labels=None, negative_srcs=None, negative_dsts=None,
MiniBatch(seed_nodes=None,
node_pairs=(tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7])),
labels=None, negative_srcs=None, negative_dsts=None,
sampled_subgraphs=None, input_nodes=None, node_features=None,
edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
3. Node pairs and labels.
>>> item_set = gb.ItemSet(
... (torch.arange(0, 20).reshape(-1, 2), torch.arange(10, 15)),
... (torch.arange(0, 20).reshape(-1, 2), torch.arange(10, 20)),
... names=("node_pairs", "labels")
... )
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False
... item_set, batch_size=4, shuffle=False, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs=tensor([[8, 9],
[4, 5],
[0, 1],
[6, 7]]), labels=tensor([14, 12, 10, 13]), negative_srcs=None,
MiniBatch(seed_nodes=None,
node_pairs=(tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7])),
labels=tensor([10, 11, 12, 13]), negative_srcs=None,
negative_dsts=None, sampled_subgraphs=None, input_nodes=None,
node_features=None, edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
...
...
@@ -150,17 +155,16 @@ class ItemSampler(IterDataPipe):
>>> item_set = gb.ItemSet((node_pairs, negative_dsts), names=("node_pairs",
... "negative_dsts"))
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False
... item_set, batch_size=4, shuffle=False, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs=tensor([[10, 11],
[ 6, 7],
[ 2, 3],
[ 8, 9]]), labels=None, negative_srcs=None,
negative_dsts=tensor([[20, 21],
[16, 17],
MiniBatch(seed_nodes=None,
node_pairs=(tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7])),
labels=None, negative_srcs=None,
negative_dsts=tensor([[10, 11],
[12, 13],
[18, 19]]), sampled_subgraphs=None, input_nodes=None,
[14, 15],
[16, 17]]), sampled_subgraphs=None, input_nodes=None,
node_features=None, edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
...
...
@@ -212,10 +216,10 @@ class ItemSampler(IterDataPipe):
... })
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None,
node_pairs={'user:like:item': tensor([[0, 1],
[2, 3],
[4, 5]
,
[6, 7]])},
labels=None, negative_srcs=None, negative_dsts=None,
MiniBatch(seed_nodes=None,
node_pairs={'user:like:item':
(tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7]))}
,
labels=None, negative_srcs=None, negative_dsts=None,
sampled_subgraphs=None, input_nodes=None, node_features=None,
edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
...
...
@@ -233,10 +237,10 @@ class ItemSampler(IterDataPipe):
... })
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None,
node_pairs={'user:like:item': tensor([[0, 1],
[2, 3],
[4, 5]
,
[6, 7]])},
labels={'user:like:item': tensor([0, 1, 2, 3])},
MiniBatch(seed_nodes=None,
node_pairs={'user:like:item':
(tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7]))}
,
labels={'user:like:item': tensor([0, 1, 2, 3])},
negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
input_nodes=None, node_features=None, edge_features=None,
compacted_node_pairs=None, compacted_negative_srcs=None,
...
...
@@ -255,10 +259,10 @@ class ItemSampler(IterDataPipe):
... })
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None,
node_pairs={'user:like:item': tensor([[0, 1],
[2, 3],
[4, 5]
,
[6, 7]])},
labels=None, negative_srcs=None,
MiniBatch(seed_nodes=None,
node_pairs={'user:like:item':
(tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7]))}
,
labels=None, negative_srcs=None,
negative_dsts={'user:like:item': tensor([[10, 11],
[12, 13],
[14, 15],
...
...
python/dgl/graphbolt/itemset.py
View file @
37d70e54
...
...
@@ -15,8 +15,9 @@ class ItemSet:
Parameters
----------
items: Iterable or Tuple[Iterable]
The items to be iterated over. If it is a tuple, each item in the tuple
is an iterable of items.
The items to be iterated over. If it is a multi-dimensional iterable such
as `torch.Tensor`, it is iterated over along its first dimension. If it
is a tuple, each item in the tuple is an iterable of items.
names: str or Tuple[str], optional
The names of the items. If it is a tuple, each name corresponds to an
item in the tuple.
...
...
tests/python/pytorch/graphbolt/test_item_sampler.py
View file @
37d70e54
...
...
@@ -183,9 +183,9 @@ def test_ItemSet_node_pairs(batch_size, shuffle, drop_last):
dst_ids
=
[]
for
i
,
minibatch
in
enumerate
(
item_sampler
):
assert
minibatch
.
node_pairs
is
not
None
assert
isinstance
(
minibatch
.
node_pairs
,
tuple
)
assert
minibatch
.
labels
is
None
src
=
minibatch
.
node_pairs
[:,
0
]
dst
=
minibatch
.
node_pairs
[:,
1
]
src
,
dst
=
minibatch
.
node_pairs
is_last
=
(
i
+
1
)
*
batch_size
>=
num_ids
if
not
is_last
or
num_ids
%
batch_size
==
0
:
expected_batch_size
=
batch_size
...
...
@@ -224,11 +224,12 @@ def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
labels
=
[]
for
i
,
minibatch
in
enumerate
(
item_sampler
):
assert
minibatch
.
node_pairs
is
not
None
assert
isinstance
(
minibatch
.
node_pairs
,
tuple
)
assert
minibatch
.
labels
is
not
None
assert
len
(
minibatch
.
node_pairs
)
==
len
(
minibatch
.
labels
)
src
=
minibatch
.
node_pairs
[:,
0
]
dst
=
minibatch
.
node_pairs
[:,
1
]
src
,
dst
=
minibatch
.
node_pairs
label
=
minibatch
.
labels
assert
len
(
src
)
==
len
(
dst
)
assert
len
(
src
)
==
len
(
label
)
is_last
=
(
i
+
1
)
*
batch_size
>=
num_ids
if
not
is_last
or
num_ids
%
batch_size
==
0
:
expected_batch_size
=
batch_size
...
...
@@ -277,9 +278,9 @@ def test_ItemSet_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
negs_ids
=
[]
for
i
,
minibatch
in
enumerate
(
item_sampler
):
assert
minibatch
.
node_pairs
is
not
None
assert
isinstance
(
minibatch
.
node_pairs
,
tuple
)
assert
minibatch
.
negative_dsts
is
not
None
src
=
minibatch
.
node_pairs
[:,
0
]
dst
=
minibatch
.
node_pairs
[:,
1
]
src
,
dst
=
minibatch
.
node_pairs
negs
=
minibatch
.
negative_dsts
is_last
=
(
i
+
1
)
*
batch_size
>=
num_ids
if
not
is_last
or
num_ids
%
batch_size
==
0
:
...
...
@@ -451,9 +452,10 @@ def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last):
assert
False
src
=
[]
dst
=
[]
for
_
,
node_pairs
in
minibatch
.
node_pairs
.
items
():
src
.
append
(
node_pairs
[:,
0
])
dst
.
append
(
node_pairs
[:,
1
])
for
_
,
(
node_pairs
)
in
minibatch
.
node_pairs
.
items
():
assert
isinstance
(
node_pairs
,
tuple
)
src
.
append
(
node_pairs
[
0
])
dst
.
append
(
node_pairs
[
1
])
src
=
torch
.
cat
(
src
)
dst
=
torch
.
cat
(
dst
)
assert
len
(
src
)
==
expected_batch_size
...
...
@@ -510,8 +512,9 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
dst
=
[]
label
=
[]
for
_
,
node_pairs
in
minibatch
.
node_pairs
.
items
():
src
.
append
(
node_pairs
[:,
0
])
dst
.
append
(
node_pairs
[:,
1
])
assert
isinstance
(
node_pairs
,
tuple
)
src
.
append
(
node_pairs
[
0
])
dst
.
append
(
node_pairs
[
1
])
for
_
,
v_label
in
minibatch
.
labels
.
items
():
label
.
append
(
v_label
)
src
=
torch
.
cat
(
src
)
...
...
@@ -582,8 +585,9 @@ def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
dst
=
[]
negs
=
[]
for
_
,
node_pairs
in
minibatch
.
node_pairs
.
items
():
src
.
append
(
node_pairs
[:,
0
])
dst
.
append
(
node_pairs
[:,
1
])
assert
isinstance
(
node_pairs
,
tuple
)
src
.
append
(
node_pairs
[
0
])
dst
.
append
(
node_pairs
[
1
])
for
_
,
v_negs
in
minibatch
.
negative_dsts
.
items
():
negs
.
append
(
v_negs
)
src
=
torch
.
cat
(
src
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment