Unverified Commit 37d70e54 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] split node_pairs to tuple of (src, dst) (#6291)

parent 50b05723
......@@ -67,6 +67,13 @@ def minibatcher_default(batch, names):
"`MiniBatch`. You probably need to provide a customized "
"`MiniBatcher`."
)
if name == "node_pairs":
# `node_pairs` is passed as a tensor of shape `(N, 2)` and
# should be converted to a tuple of `(src, dst)`.
if isinstance(item, Mapping):
item = {key: (item[key][:, 0], item[key][:, 1]) for key in item}
else:
item = (item[:, 0], item[:, 1])
setattr(minibatch, name, item)
return minibatch
......@@ -103,10 +110,10 @@ class ItemSampler(IterDataPipe):
>>> from dgl import graphbolt as gb
>>> item_set = gb.ItemSet(torch.arange(0, 10), names="seed_nodes")
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False
... item_set, batch_size=4, shuffle=False, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=tensor([9, 0, 7, 2]), node_pairs=None, labels=None,
MiniBatch(seed_nodes=tensor([0, 1, 2, 3]), node_pairs=None, labels=None,
negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
input_nodes=None, node_features=None, edge_features=None,
compacted_node_pairs=None, compacted_negative_srcs=None,
......@@ -116,30 +123,28 @@ class ItemSampler(IterDataPipe):
>>> item_set = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2),
... names="node_pairs")
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False
... item_set, batch_size=4, shuffle=False, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs=tensor([[16, 17],
[ 4, 5],
[ 6, 7],
[10, 11]]), labels=None, negative_srcs=None, negative_dsts=None,
MiniBatch(seed_nodes=None,
node_pairs=(tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7])),
labels=None, negative_srcs=None, negative_dsts=None,
sampled_subgraphs=None, input_nodes=None, node_features=None,
edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
3. Node pairs and labels.
>>> item_set = gb.ItemSet(
... (torch.arange(0, 20).reshape(-1, 2), torch.arange(10, 15)),
... (torch.arange(0, 20).reshape(-1, 2), torch.arange(10, 20)),
... names=("node_pairs", "labels")
... )
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False
... item_set, batch_size=4, shuffle=False, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs=tensor([[8, 9],
[4, 5],
[0, 1],
[6, 7]]), labels=tensor([14, 12, 10, 13]), negative_srcs=None,
MiniBatch(seed_nodes=None,
node_pairs=(tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7])),
labels=tensor([10, 11, 12, 13]), negative_srcs=None,
negative_dsts=None, sampled_subgraphs=None, input_nodes=None,
node_features=None, edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
......@@ -150,17 +155,16 @@ class ItemSampler(IterDataPipe):
>>> item_set = gb.ItemSet((node_pairs, negative_dsts), names=("node_pairs",
... "negative_dsts"))
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False
... item_set, batch_size=4, shuffle=False, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs=tensor([[10, 11],
[ 6, 7],
[ 2, 3],
[ 8, 9]]), labels=None, negative_srcs=None,
negative_dsts=tensor([[20, 21],
[16, 17],
MiniBatch(seed_nodes=None,
node_pairs=(tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7])),
labels=None, negative_srcs=None,
negative_dsts=tensor([[10, 11],
[12, 13],
[18, 19]]), sampled_subgraphs=None, input_nodes=None,
[14, 15],
[16, 17]]), sampled_subgraphs=None, input_nodes=None,
node_features=None, edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
......@@ -212,10 +216,10 @@ class ItemSampler(IterDataPipe):
... })
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs={'user:like:item': tensor([[0, 1],
[2, 3],
[4, 5],
[6, 7]])}, labels=None, negative_srcs=None, negative_dsts=None,
MiniBatch(seed_nodes=None,
node_pairs={'user:like:item':
(tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7]))},
labels=None, negative_srcs=None, negative_dsts=None,
sampled_subgraphs=None, input_nodes=None, node_features=None,
edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
......@@ -233,10 +237,10 @@ class ItemSampler(IterDataPipe):
... })
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs={'user:like:item': tensor([[0, 1],
[2, 3],
[4, 5],
[6, 7]])}, labels={'user:like:item': tensor([0, 1, 2, 3])},
MiniBatch(seed_nodes=None,
node_pairs={'user:like:item':
(tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7]))},
labels={'user:like:item': tensor([0, 1, 2, 3])},
negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
input_nodes=None, node_features=None, edge_features=None,
compacted_node_pairs=None, compacted_negative_srcs=None,
......@@ -255,10 +259,10 @@ class ItemSampler(IterDataPipe):
... })
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs={'user:like:item': tensor([[0, 1],
[2, 3],
[4, 5],
[6, 7]])}, labels=None, negative_srcs=None,
MiniBatch(seed_nodes=None,
node_pairs={'user:like:item':
(tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7]))},
labels=None, negative_srcs=None,
negative_dsts={'user:like:item': tensor([[10, 11],
[12, 13],
[14, 15],
......
......@@ -15,8 +15,9 @@ class ItemSet:
Parameters
----------
items: Iterable or Tuple[Iterable]
The items to be iterated over. If it is a tuple, each item in the tuple
is an iterable of items.
The items to be iterated over. If it is a multi-dimensional iterable such
as `torch.Tensor`, it will be iterated over along the first dimension. If
it is a tuple, each item in the tuple is an iterable of items.
names: str or Tuple[str], optional
The names of the items. If it is a tuple, each name corresponds to an
item in the tuple.
......
......@@ -183,9 +183,9 @@ def test_ItemSet_node_pairs(batch_size, shuffle, drop_last):
dst_ids = []
for i, minibatch in enumerate(item_sampler):
assert minibatch.node_pairs is not None
assert isinstance(minibatch.node_pairs, tuple)
assert minibatch.labels is None
src = minibatch.node_pairs[:, 0]
dst = minibatch.node_pairs[:, 1]
src, dst = minibatch.node_pairs
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
expected_batch_size = batch_size
......@@ -224,11 +224,12 @@ def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
labels = []
for i, minibatch in enumerate(item_sampler):
assert minibatch.node_pairs is not None
assert isinstance(minibatch.node_pairs, tuple)
assert minibatch.labels is not None
assert len(minibatch.node_pairs) == len(minibatch.labels)
src = minibatch.node_pairs[:, 0]
dst = minibatch.node_pairs[:, 1]
src, dst = minibatch.node_pairs
label = minibatch.labels
assert len(src) == len(dst)
assert len(src) == len(label)
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
expected_batch_size = batch_size
......@@ -277,9 +278,9 @@ def test_ItemSet_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
negs_ids = []
for i, minibatch in enumerate(item_sampler):
assert minibatch.node_pairs is not None
assert isinstance(minibatch.node_pairs, tuple)
assert minibatch.negative_dsts is not None
src = minibatch.node_pairs[:, 0]
dst = minibatch.node_pairs[:, 1]
src, dst = minibatch.node_pairs
negs = minibatch.negative_dsts
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
......@@ -451,9 +452,10 @@ def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last):
assert False
src = []
dst = []
for _, node_pairs in minibatch.node_pairs.items():
src.append(node_pairs[:, 0])
dst.append(node_pairs[:, 1])
for _, (node_pairs) in minibatch.node_pairs.items():
assert isinstance(node_pairs, tuple)
src.append(node_pairs[0])
dst.append(node_pairs[1])
src = torch.cat(src)
dst = torch.cat(dst)
assert len(src) == expected_batch_size
......@@ -510,8 +512,9 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
dst = []
label = []
for _, node_pairs in minibatch.node_pairs.items():
src.append(node_pairs[:, 0])
dst.append(node_pairs[:, 1])
assert isinstance(node_pairs, tuple)
src.append(node_pairs[0])
dst.append(node_pairs[1])
for _, v_label in minibatch.labels.items():
label.append(v_label)
src = torch.cat(src)
......@@ -582,8 +585,9 @@ def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
dst = []
negs = []
for _, node_pairs in minibatch.node_pairs.items():
src.append(node_pairs[:, 0])
dst.append(node_pairs[:, 1])
assert isinstance(node_pairs, tuple)
src.append(node_pairs[0])
dst.append(node_pairs[1])
for _, v_negs in minibatch.negative_dsts.items():
negs.append(v_negs)
src = torch.cat(src)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment