Unverified Commit 37d70e54 authored by Rhett Ying, committed by GitHub

[GraphBolt] split node_pairs to tuple of (src, dst) (#6291)

parent 50b05723
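For context, the gist of this change is the in-memory format of `node_pairs`: a minibatch no longer carries a single `(N, 2)` tensor but a tuple of `(src, dst)` tensors (or a dict of such tuples for heterogeneous graphs). A minimal sketch of the format change, using plain `torch` and values matching the updated docstrings below:

```python
import torch

# Before this change: `node_pairs` was a single tensor of shape (N, 2).
node_pairs_old = torch.arange(0, 8).reshape(-1, 2)
# tensor([[0, 1], [2, 3], [4, 5], [6, 7]])

# After this change: `node_pairs` is a tuple of (src, dst) tensors.
node_pairs_new = (node_pairs_old[:, 0], node_pairs_old[:, 1])
# (tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7]))
```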
@@ -67,6 +67,13 @@ def minibatcher_default(batch, names):
                 "`MiniBatch`. You probably need to provide a customized "
                 "`MiniBatcher`."
             )
+        if name == "node_pairs":
+            # `node_pairs` is passed in as a tensor of shape `(N, 2)` and
+            # should be converted to a tuple of `(src, dst)`.
+            if isinstance(item, Mapping):
+                item = {key: (item[key][:, 0], item[key][:, 1]) for key in item}
+            else:
+                item = (item[:, 0], item[:, 1])
         setattr(minibatch, name, item)
     return minibatch
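A sketch of what the new branch does, lifted out as a standalone function so it can be tried in isolation (the name `split_node_pairs` is illustrative, not part of GraphBolt):

```python
from collections.abc import Mapping

import torch


def split_node_pairs(item):
    # Illustrative helper mirroring the branch added above: a (N, 2) tensor,
    # or a dict of such tensors keyed by edge type, becomes (src, dst) tuples.
    if isinstance(item, Mapping):
        return {key: (item[key][:, 0], item[key][:, 1]) for key in item}
    return (item[:, 0], item[:, 1])


pairs = torch.tensor([[0, 1], [2, 3], [4, 5], [6, 7]])
print(split_node_pairs(pairs))
# (tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7]))
print(split_node_pairs({"user:like:item": pairs}))
# {'user:like:item': (tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7]))}
```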
@@ -103,10 +110,10 @@ class ItemSampler(IterDataPipe):
     >>> from dgl import graphbolt as gb
     >>> item_set = gb.ItemSet(torch.arange(0, 10), names="seed_nodes")
     >>> item_sampler = gb.ItemSampler(
-    ...     item_set, batch_size=4, shuffle=True, drop_last=False
+    ...     item_set, batch_size=4, shuffle=False, drop_last=False
     ... )
     >>> next(iter(item_sampler))
-    MiniBatch(seed_nodes=tensor([9, 0, 7, 2]), node_pairs=None, labels=None,
+    MiniBatch(seed_nodes=tensor([0, 1, 2, 3]), node_pairs=None, labels=None,
         negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
         input_nodes=None, node_features=None, edge_features=None,
         compacted_node_pairs=None, compacted_negative_srcs=None,
@@ -116,30 +123,28 @@ class ItemSampler(IterDataPipe):
     >>> item_set = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2),
     ...     names="node_pairs")
     >>> item_sampler = gb.ItemSampler(
-    ...     item_set, batch_size=4, shuffle=True, drop_last=False
+    ...     item_set, batch_size=4, shuffle=False, drop_last=False
     ... )
     >>> next(iter(item_sampler))
-    MiniBatch(seed_nodes=None, node_pairs=tensor([[16, 17],
-            [ 4,  5],
-            [ 6,  7],
-            [10, 11]]), labels=None, negative_srcs=None, negative_dsts=None,
+    MiniBatch(seed_nodes=None,
+        node_pairs=(tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7])),
+        labels=None, negative_srcs=None, negative_dsts=None,
         sampled_subgraphs=None, input_nodes=None, node_features=None,
         edge_features=None, compacted_node_pairs=None,
         compacted_negative_srcs=None, compacted_negative_dsts=None)

     3. Node pairs and labels.

     >>> item_set = gb.ItemSet(
-    ...     (torch.arange(0, 20).reshape(-1, 2), torch.arange(10, 15)),
+    ...     (torch.arange(0, 20).reshape(-1, 2), torch.arange(10, 20)),
     ...     names=("node_pairs", "labels")
     ... )
     >>> item_sampler = gb.ItemSampler(
-    ...     item_set, batch_size=4, shuffle=True, drop_last=False
+    ...     item_set, batch_size=4, shuffle=False, drop_last=False
     ... )
     >>> next(iter(item_sampler))
-    MiniBatch(seed_nodes=None, node_pairs=tensor([[8, 9],
-            [4, 5],
-            [0, 1],
-            [6, 7]]), labels=tensor([14, 12, 10, 13]), negative_srcs=None,
+    MiniBatch(seed_nodes=None,
+        node_pairs=(tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7])),
+        labels=tensor([10, 11, 12, 13]), negative_srcs=None,
         negative_dsts=None, sampled_subgraphs=None, input_nodes=None,
         node_features=None, edge_features=None, compacted_node_pairs=None,
         compacted_negative_srcs=None, compacted_negative_dsts=None)
@@ -150,17 +155,16 @@ class ItemSampler(IterDataPipe):
     >>> item_set = gb.ItemSet((node_pairs, negative_dsts), names=("node_pairs",
     ...     "negative_dsts"))
     >>> item_sampler = gb.ItemSampler(
-    ...     item_set, batch_size=4, shuffle=True, drop_last=False
+    ...     item_set, batch_size=4, shuffle=False, drop_last=False
     ... )
     >>> next(iter(item_sampler))
-    MiniBatch(seed_nodes=None, node_pairs=tensor([[10, 11],
-            [ 6,  7],
-            [ 2,  3],
-            [ 8,  9]]), labels=None, negative_srcs=None,
-        negative_dsts=tensor([[20, 21],
-            [16, 17],
+    MiniBatch(seed_nodes=None,
+        node_pairs=(tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7])),
+        labels=None, negative_srcs=None,
+        negative_dsts=tensor([[10, 11],
             [12, 13],
-            [18, 19]]), sampled_subgraphs=None, input_nodes=None,
+            [14, 15],
+            [16, 17]]), sampled_subgraphs=None, input_nodes=None,
         node_features=None, edge_features=None, compacted_node_pairs=None,
         compacted_negative_srcs=None, compacted_negative_dsts=None)
@@ -212,10 +216,10 @@ class ItemSampler(IterDataPipe):
     ... })
     >>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
     >>> next(iter(item_sampler))
-    MiniBatch(seed_nodes=None, node_pairs={'user:like:item': tensor([[0, 1],
-            [2, 3],
-            [4, 5],
-            [6, 7]])}, labels=None, negative_srcs=None, negative_dsts=None,
+    MiniBatch(seed_nodes=None,
+        node_pairs={'user:like:item':
+        (tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7]))},
+        labels=None, negative_srcs=None, negative_dsts=None,
         sampled_subgraphs=None, input_nodes=None, node_features=None,
         edge_features=None, compacted_node_pairs=None,
         compacted_negative_srcs=None, compacted_negative_dsts=None)
@@ -233,10 +237,10 @@ class ItemSampler(IterDataPipe):
     ... })
     >>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
     >>> next(iter(item_sampler))
-    MiniBatch(seed_nodes=None, node_pairs={'user:like:item': tensor([[0, 1],
-            [2, 3],
-            [4, 5],
-            [6, 7]])}, labels={'user:like:item': tensor([0, 1, 2, 3])},
+    MiniBatch(seed_nodes=None,
+        node_pairs={'user:like:item':
+        (tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7]))},
+        labels={'user:like:item': tensor([0, 1, 2, 3])},
         negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
         input_nodes=None, node_features=None, edge_features=None,
         compacted_node_pairs=None, compacted_negative_srcs=None,
@@ -255,10 +259,10 @@ class ItemSampler(IterDataPipe):
     ... })
     >>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
     >>> next(iter(item_sampler))
-    MiniBatch(seed_nodes=None, node_pairs={'user:like:item': tensor([[0, 1],
-            [2, 3],
-            [4, 5],
-            [6, 7]])}, labels=None, negative_srcs=None,
+    MiniBatch(seed_nodes=None,
+        node_pairs={'user:like:item':
+        (tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7]))},
+        labels=None, negative_srcs=None,
         negative_dsts={'user:like:item': tensor([[10, 11],
             [12, 13],
             [14, 15],
...
@@ -15,8 +15,9 @@ class ItemSet:
     Parameters
     ----------
     items: Iterable or Tuple[Iterable]
-        The items to be iterated over. If it is a tuple, each item in the tuple
-        is an iterable of items.
+        The items to be iterated over. If it is a multi-dimensional iterable
+        such as a `torch.Tensor`, it is iterated over the first dimension. If
+        it is a tuple, each item in the tuple is an iterable of items.
     names: str or Tuple[str], optional
         The names of the items. If it is a tuple, each name corresponds to an
         item in the tuple.
...
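To illustrate the wording added to the docstring above with plain `torch` (no `gb.ItemSet` needed for the point): iterating a multi-dimensional tensor yields slices along its first dimension, which is why an item set backed by an `(N, 2)` tensor of node pairs yields one pair per step.

```python
import torch

pairs = torch.arange(0, 6).reshape(-1, 2)
for pair in pairs:
    # Iteration is over the first dimension:
    # tensor([0, 1]), then tensor([2, 3]), then tensor([4, 5]).
    print(pair)
```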
@@ -183,9 +183,9 @@ def test_ItemSet_node_pairs(batch_size, shuffle, drop_last):
     dst_ids = []
     for i, minibatch in enumerate(item_sampler):
         assert minibatch.node_pairs is not None
+        assert isinstance(minibatch.node_pairs, tuple)
         assert minibatch.labels is None
-        src = minibatch.node_pairs[:, 0]
-        dst = minibatch.node_pairs[:, 1]
+        src, dst = minibatch.node_pairs
         is_last = (i + 1) * batch_size >= num_ids
         if not is_last or num_ids % batch_size == 0:
             expected_batch_size = batch_size
@@ -224,11 +224,12 @@ def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
     labels = []
     for i, minibatch in enumerate(item_sampler):
         assert minibatch.node_pairs is not None
+        assert isinstance(minibatch.node_pairs, tuple)
         assert minibatch.labels is not None
-        assert len(minibatch.node_pairs) == len(minibatch.labels)
-        src = minibatch.node_pairs[:, 0]
-        dst = minibatch.node_pairs[:, 1]
+        src, dst = minibatch.node_pairs
         label = minibatch.labels
+        assert len(src) == len(dst)
+        assert len(src) == len(label)
         is_last = (i + 1) * batch_size >= num_ids
         if not is_last or num_ids % batch_size == 0:
             expected_batch_size = batch_size
@@ -277,9 +278,9 @@ def test_ItemSet_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
     negs_ids = []
     for i, minibatch in enumerate(item_sampler):
         assert minibatch.node_pairs is not None
+        assert isinstance(minibatch.node_pairs, tuple)
         assert minibatch.negative_dsts is not None
-        src = minibatch.node_pairs[:, 0]
-        dst = minibatch.node_pairs[:, 1]
+        src, dst = minibatch.node_pairs
         negs = minibatch.negative_dsts
         is_last = (i + 1) * batch_size >= num_ids
         if not is_last or num_ids % batch_size == 0:
@@ -451,9 +452,10 @@ def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last):
             assert False
         src = []
         dst = []
         for _, node_pairs in minibatch.node_pairs.items():
-            src.append(node_pairs[:, 0])
-            dst.append(node_pairs[:, 1])
+            assert isinstance(node_pairs, tuple)
+            src.append(node_pairs[0])
+            dst.append(node_pairs[1])
         src = torch.cat(src)
         dst = torch.cat(dst)
         assert len(src) == expected_batch_size
@@ -510,8 +512,9 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
         dst = []
         label = []
         for _, node_pairs in minibatch.node_pairs.items():
-            src.append(node_pairs[:, 0])
-            dst.append(node_pairs[:, 1])
+            assert isinstance(node_pairs, tuple)
+            src.append(node_pairs[0])
+            dst.append(node_pairs[1])
         for _, v_label in minibatch.labels.items():
             label.append(v_label)
         src = torch.cat(src)
@@ -582,8 +585,9 @@ def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
         dst = []
         negs = []
         for _, node_pairs in minibatch.node_pairs.items():
-            src.append(node_pairs[:, 0])
-            dst.append(node_pairs[:, 1])
+            assert isinstance(node_pairs, tuple)
+            src.append(node_pairs[0])
+            dst.append(node_pairs[1])
         for _, v_negs in minibatch.negative_dsts.items():
             negs.append(v_negs)
         src = torch.cat(src)
...
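The ItemSetDict tests above all follow the same pattern after this change; a small self-contained sketch of that pattern (the dict below is a stand-in for `minibatch.node_pairs`, and the edge-type keys are only examples):

```python
import torch

# Per-edge-type node pairs are now (src, dst) tuples, so src/dst are gathered
# by unpacking and indexing the tuple instead of slicing tensor columns.
node_pairs = {
    "user:like:item": (torch.tensor([0, 2]), torch.tensor([1, 3])),
    "user:follow:user": (torch.tensor([4, 6]), torch.tensor([5, 7])),
}
src, dst = [], []
for _, pairs in node_pairs.items():
    assert isinstance(pairs, tuple)
    src.append(pairs[0])
    dst.append(pairs[1])
src = torch.cat(src)  # tensor([0, 2, 4, 6])
dst = torch.cat(dst)  # tensor([1, 3, 5, 7])
```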