Unverified commit ba9c1521, authored by Xinyu Yao, committed by GitHub
Browse files

[GraphBolt] Update docstring related to cleaning up `seed_nodes` and `node_pairs`. (#7341)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 22516834
......@@ -34,7 +34,7 @@ class InSubgraphSampler(SubgraphSampler):
>>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12, 14])
>>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4])
>>> graph = gb.fused_csc_sampling_graph(indptr, indices)
>>> item_set = gb.ItemSet(len(indptr) - 1, names="seed_nodes")
>>> item_set = gb.ItemSet(len(indptr) - 1, names="seeds")
>>> item_sampler = gb.ItemSampler(item_set, batch_size=2)
>>> insubgraph_sampler = gb.InSubgraphSampler(item_sampler, graph)
>>> for _, data in enumerate(insubgraph_sampler):
......
......@@ -407,8 +407,8 @@ class NeighborSampler(NeighborSamplerImpl):
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
>>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
>>> graph = gb.fused_csc_sampling_graph(indptr, indices)
>>> node_pairs = torch.LongTensor([[0, 1], [1, 2]])
>>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
>>> seeds = torch.LongTensor([[0, 1], [1, 2]])
>>> item_set = gb.ItemSet(seeds, names="seeds")
>>> datapipe = gb.ItemSampler(item_set, batch_size=1)
>>> datapipe = datapipe.sample_uniform_negative(graph, 2)
>>> datapipe = datapipe.sample_neighbor(graph, [5, 10, 15])
......@@ -534,8 +534,8 @@ class LayerNeighborSampler(NeighborSamplerImpl):
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
>>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
>>> graph = gb.fused_csc_sampling_graph(indptr, indices)
>>> node_pairs = torch.LongTensor([[0, 1], [1, 2]])
>>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
>>> seeds = torch.LongTensor([[0, 1], [1, 2]])
>>> item_set = gb.ItemSet(seeds, names="seeds")
>>> item_sampler = gb.ItemSampler(item_set, batch_size=1,)
>>> neg_sampler = gb.UniformNegativeSampler(item_sampler, graph, 2)
>>> fanouts = [torch.LongTensor([5]),
......@@ -566,8 +566,12 @@ class LayerNeighborSampler(NeighborSamplerImpl):
original_edge_ids=None,
original_column_node_ids=tensor([0, 1, 5, 2]),
)]
>>> next(iter(subgraph_sampler)).compacted_node_pairs
(tensor([0]), tensor([1]))
>>> next(iter(subgraph_sampler)).compacted_seeds
tensor([[0, 1], [0, 2], [0, 3]])
>>> next(iter(subgraph_sampler)).labels
tensor([1., 0., 0.])
>>> next(iter(subgraph_sampler)).indexes
tensor([0, 0, 0])
"""
def __init__(
......
......@@ -42,11 +42,7 @@ from .torch_based_feature_store import TorchBasedFeatureStore
__all__ = ["OnDiskDataset", "preprocess_ondisk_dataset", "BuiltinDataset"]
NAMES_INDICATING_NODE_IDS = [
"seed_nodes",
"node_pairs",
"seeds",
"negative_srcs",
"negative_dsts",
]
......
......@@ -36,20 +36,20 @@ class UniformNegativeSampler(NegativeSampler):
>>> indptr = torch.LongTensor([0, 1, 2, 3, 4])
>>> indices = torch.LongTensor([1, 2, 3, 0])
>>> graph = gb.fused_csc_sampling_graph(indptr, indices)
>>> node_pairs = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 0]])
>>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
>>> seeds = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 0]])
>>> item_set = gb.ItemSet(seeds, names="seeds")
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4,)
>>> neg_sampler = gb.UniformNegativeSampler(
... item_sampler, graph, 2)
>>> for minibatch in neg_sampler:
... print(minibatch.negative_srcs)
... print(minibatch.negative_dsts)
None
tensor([[2, 1],
[2, 1],
[3, 2],
[1, 3]])
... print(minibatch.seeds)
... print(minibatch.labels)
... print(minibatch.indexes)
tensor([[0, 1], [1, 2], [2, 3], [3, 0], [0, 1], [0, 3], [1, 1], [1, 2],
[2, 1], [2, 0], [3, 0], [3, 2]])
tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.])
tensor([0, 1, 2, 3, 0, 0, 1, 1, 2, 2, 3, 3])
"""
def __init__(
......
......@@ -50,7 +50,7 @@ def minibatcher_default(batch, names):
return batch
if len(names) == 1:
# Handle the case of single item: batch = tensor([0, 1, 2, 3]), names =
# ("seed_nodes",) as `zip(batch, names)` will iterate over the tensor
# ("seeds",) as `zip(batch, names)` will iterate over the tensor
# instead of the batch.
init_data = {names[0]: batch}
else:
......@@ -313,68 +313,61 @@ class ItemSampler(IterDataPipe):
>>> import torch
>>> from dgl import graphbolt as gb
>>> item_set = gb.ItemSet(torch.arange(0, 10), names="seed_nodes")
>>> item_set = gb.ItemSet(torch.arange(0, 10), names="seeds")
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=False, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=tensor([0, 1, 2, 3]), node_pairs=None, labels=None,
negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
input_nodes=None, node_features=None, edge_features=None,
compacted_node_pairs=None, compacted_negative_srcs=None,
compacted_negative_dsts=None)
MiniBatch(seeds=tensor([0, 1, 2, 3]), sampled_subgraphs=None,
node_features=None, labels=None, input_nodes=None,
indexes=None, edge_features=None, compacted_seeds=None,
blocks=None,)
2. Node pairs.
>>> item_set = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2),
... names="node_pairs")
... names="seeds")
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=False, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None,
node_pairs=(tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7])),
labels=None, negative_srcs=None, negative_dsts=None,
sampled_subgraphs=None, input_nodes=None, node_features=None,
edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
MiniBatch(seeds=tensor([[0, 1], [2, 3], [4, 5], [6, 7]]),
sampled_subgraphs=None, node_features=None, labels=None,
input_nodes=None, indexes=None, edge_features=None,
compacted_seeds=None, blocks=None,)
3. Node pairs and labels.
>>> item_set = gb.ItemSet(
... (torch.arange(0, 20).reshape(-1, 2), torch.arange(10, 20)),
... names=("node_pairs", "labels")
... names=("seeds", "labels")
... )
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=False, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None,
node_pairs=(tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7])),
labels=tensor([10, 11, 12, 13]), negative_srcs=None,
negative_dsts=None, sampled_subgraphs=None, input_nodes=None,
node_features=None, edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
4. Node pairs and negative destinations.
>>> node_pairs = torch.arange(0, 20).reshape(-1, 2)
>>> negative_dsts = torch.arange(10, 30).reshape(-1, 2)
>>> item_set = gb.ItemSet((node_pairs, negative_dsts), names=("node_pairs",
... "negative_dsts"))
MiniBatch(seeds=tensor([[0, 1], [2, 3], [4, 5], [6, 7]]),
sampled_subgraphs=None, node_features=None,
labels=tensor([10, 11, 12, 13]), input_nodes=None,
indexes=None, edge_features=None, compacted_seeds=None,
blocks=None,)
4. Node pairs, labels and indexes.
>>> seeds = torch.arange(0, 20).reshape(-1, 2)
>>> labels = torch.tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0])
>>> indexes = torch.tensor([0, 1, 0, 0, 0, 0, 1, 1, 1, 1])
>>> item_set = gb.ItemSet((seeds, labels, indexes), names=("seeds",
... "labels", "indexes"))
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=False, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None,
node_pairs=(tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7])),
labels=None, negative_srcs=None,
negative_dsts=tensor([[10, 11],
[12, 13],
[14, 15],
[16, 17]]), sampled_subgraphs=None, input_nodes=None,
node_features=None, edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
MiniBatch(seeds=tensor([[0, 1], [2, 3], [4, 5], [6, 7]]),
sampled_subgraphs=None, node_features=None,
labels=tensor([1, 1, 0, 0]), input_nodes=None,
indexes=tensor([0, 1, 0, 0]), edge_features=None,
compacted_seeds=None, blocks=None,)
5. DGLGraphs.
......@@ -404,85 +397,74 @@ class ItemSampler(IterDataPipe):
7. Heterogeneous node IDs.
>>> ids = {
... "user": gb.ItemSet(torch.arange(0, 5), names="seed_nodes"),
... "item": gb.ItemSet(torch.arange(0, 6), names="seed_nodes"),
... "user": gb.ItemSet(torch.arange(0, 5), names="seeds"),
... "item": gb.ItemSet(torch.arange(0, 6), names="seeds"),
... }
>>> item_set = gb.ItemSetDict(ids)
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seed_nodes={'user': tensor([0, 1, 2, 3])}, node_pairs=None,
labels=None, negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
input_nodes=None, node_features=None, edge_features=None,
compacted_node_pairs=None, compacted_negative_srcs=None,
compacted_negative_dsts=None)
MiniBatch(seeds={'user': tensor([0, 1, 2, 3])}, sampled_subgraphs=None,
node_features=None, labels=None, input_nodes=None, indexes=None,
edge_features=None, compacted_seeds=None, blocks=None,)
8. Heterogeneous node pairs.
>>> node_pairs_like = torch.arange(0, 10).reshape(-1, 2)
>>> node_pairs_follow = torch.arange(10, 20).reshape(-1, 2)
>>> seeds_like = torch.arange(0, 10).reshape(-1, 2)
>>> seeds_follow = torch.arange(10, 20).reshape(-1, 2)
>>> item_set = gb.ItemSetDict({
... "user:like:item": gb.ItemSet(
... node_pairs_like, names="node_pairs"),
... seeds_like, names="seeds"),
... "user:follow:user": gb.ItemSet(
... node_pairs_follow, names="node_pairs"),
... seeds_follow, names="seeds"),
... })
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None,
node_pairs={'user:like:item':
(tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7]))},
labels=None, negative_srcs=None, negative_dsts=None,
sampled_subgraphs=None, input_nodes=None, node_features=None,
edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
MiniBatch(seeds={'user:like:item':
tensor([[0, 1], [2, 3], [4, 5], [6, 7]])}, sampled_subgraphs=None,
node_features=None, labels=None, input_nodes=None, indexes=None,
edge_features=None, compacted_seeds=None, blocks=None,)
9. Heterogeneous node pairs and labels.
>>> node_pairs_like = torch.arange(0, 10).reshape(-1, 2)
>>> labels_like = torch.arange(0, 10)
>>> node_pairs_follow = torch.arange(10, 20).reshape(-1, 2)
>>> labels_follow = torch.arange(10, 20)
>>> seeds_like = torch.arange(0, 10).reshape(-1, 2)
>>> labels_like = torch.arange(0, 5)
>>> seeds_follow = torch.arange(10, 20).reshape(-1, 2)
>>> labels_follow = torch.arange(5, 10)
>>> item_set = gb.ItemSetDict({
... "user:like:item": gb.ItemSet((node_pairs_like, labels_like),
... names=("node_pairs", "labels")),
... "user:follow:user": gb.ItemSet((node_pairs_follow, labels_follow),
... names=("node_pairs", "labels")),
... "user:like:item": gb.ItemSet((seeds_like, labels_like),
... names=("seeds", "labels")),
... "user:follow:user": gb.ItemSet((seeds_follow, labels_follow),
... names=("seeds", "labels")),
... })
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None,
node_pairs={'user:like:item':
(tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7]))},
labels={'user:like:item': tensor([0, 1, 2, 3])},
negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
input_nodes=None, node_features=None, edge_features=None,
compacted_node_pairs=None, compacted_negative_srcs=None,
compacted_negative_dsts=None)
10. Heterogeneous node pairs and negative destinations.
>>> node_pairs_like = torch.arange(0, 10).reshape(-1, 2)
>>> negative_dsts_like = torch.arange(10, 20).reshape(-1, 2)
>>> node_pairs_follow = torch.arange(20, 30).reshape(-1, 2)
>>> negative_dsts_follow = torch.arange(30, 40).reshape(-1, 2)
MiniBatch(seeds={'user:like:item':
tensor([[0, 1], [2, 3], [4, 5], [6, 7]])}, sampled_subgraphs=None,
node_features=None, labels={'user:like:item': tensor([0, 1, 2, 3])},
input_nodes=None, indexes=None, edge_features=None,
compacted_seeds=None, blocks=None,)
10. Heterogeneous node pairs, labels and indexes.
>>> seeds_like = torch.arange(0, 10).reshape(-1, 2)
>>> labels_like = torch.tensor([1, 1, 0, 0, 0])
>>> indexes_like = torch.tensor([0, 1, 0, 0, 1])
>>> seeds_follow = torch.arange(20, 30).reshape(-1, 2)
>>> labels_follow = torch.tensor([1, 1, 0, 0, 0])
>>> indexes_follow = torch.tensor([0, 1, 0, 0, 1])
>>> item_set = gb.ItemSetDict({
... "user:like:item": gb.ItemSet((node_pairs_like, negative_dsts_like),
... names=("node_pairs", "negative_dsts")),
... "user:follow:user": gb.ItemSet((node_pairs_follow,
... negative_dsts_follow), names=("node_pairs", "negative_dsts")),
... "user:like:item": gb.ItemSet((seeds_like, labels_like,
... indexes_like), names=("seeds", "labels", "indexes")),
... "user:follow:user": gb.ItemSet((seeds_follow,labels_follow,
... indexes_follow), names=("seeds", "labels", "indexes")),
... })
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None,
node_pairs={'user:like:item':
(tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7]))},
labels=None, negative_srcs=None,
negative_dsts={'user:like:item': tensor([[10, 11],
[12, 13],
[14, 15],
[16, 17]])}, sampled_subgraphs=None, input_nodes=None,
node_features=None, edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
MiniBatch(seeds={'user:like:item':
tensor([[0, 1], [2, 3], [4, 5], [6, 7]])}, sampled_subgraphs=None,
node_features=None, labels={'user:like:item': tensor([1, 1, 0, 0])},
input_nodes=None, indexes={'user:like:item': tensor([0, 1, 0, 0])},
edge_features=None, compacted_seeds=None, blocks=None,)
"""
def __init__(
......
......@@ -26,11 +26,11 @@ class MiniBatch:
labels: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Labels associated with seed nodes / node pairs in the graph.
Labels associated with seeds in the graph.
- If `labels` is a tensor: It indicates the graph is homogeneous. The value
should be corresponding labels to given 'seed_nodes' or 'node_pairs'.
should be the labels corresponding to the given 'seeds'.
- If `labels` is a dictionary: The keys should be node or edge type and the
value should be corresponding labels to given 'seed_nodes' or 'node_pairs'.
value should be the labels corresponding to the given 'seeds'.
"""
seeds: Union[
......@@ -61,15 +61,14 @@ class MiniBatch:
indexes: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Indexes associated with seed nodes / node pairs in the graph, which
indicates to which query a seed node / node pair belongs.
Indexes associated with seeds in the graph, which
indicate to which query each seed belongs.
- If `indexes` is a tensor: It indicates the graph is homogeneous. The
value should be corresponding query to given 'seed_nodes' or
'node_pairs'.
- If `indexes` is a dictionary: It indicates the graph is
heterogeneous. The keys should be node or edge type and the value should
be corresponding query to given 'seed_nodes' or 'node_pairs'. For each
key, indexes are consecutive integers starting from zero.
value should be the query corresponding to the given 'seeds'.
- If `indexes` is a dictionary: It indicates the graph is heterogeneous.
The keys should be node or edge types and the value should be the
query corresponding to the given 'seeds'. For each key, indexes are
consecutive integers starting from zero.
"""
sampled_subgraphs: List[SampledSubgraph] = None
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment