Unverified Commit 5a24d02d authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Update docstring examples of `NeighborSampler` and `LayerNeighborSampler`. (#6820)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 60a27bc6
...@@ -73,21 +73,24 @@ class NeighborSampler(SubgraphSampler): ...@@ -73,21 +73,24 @@ class NeighborSampler(SubgraphSampler):
>>> datapipe = datapipe.sample_uniform_negative(graph, 2) >>> datapipe = datapipe.sample_uniform_negative(graph, 2)
>>> datapipe = datapipe.sample_neighbor(graph, [5, 10, 15]) >>> datapipe = datapipe.sample_neighbor(graph, [5, 10, 15])
>>> next(iter(datapipe)).sampled_subgraphs >>> next(iter(datapipe)).sampled_subgraphs
[FusedSampledSubgraphImpl(original_row_node_ids=tensor([0, 1, 3, 2, 4, 5]), [SampledSubgraphImpl(original_row_node_ids=tensor([0, 1, 3, 4, 2, 5]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([0, 1, 3, 2, 4, 5]), original_column_node_ids=tensor([0, 1, 3, 4, 2, 5]),
node_pairs=(tensor([1, 3, 0, 2, 4, 5, 2, 5]), node_pairs=CSCFormatBase(indptr=tensor([0, 2, 4, 5, 6, 7, 8]),
tensor([0, 0, 1, 1, 2, 3, 4, 5])),), indices=tensor([1, 4, 0, 2, 3, 2, 5, 5]),
FusedSampledSubgraphImpl(original_row_node_ids=tensor([0, 1, 3, 2, 4, 5]), ),
original_edge_ids=None, ), SampledSubgraphImpl(original_row_node_ids=tensor([0, 1, 3, 4, 2, 5]),
original_column_node_ids=tensor([0, 1, 3, 2, 4]), original_edge_ids=None,
node_pairs=(tensor([1, 3, 0, 2, 4, 5, 2]), original_column_node_ids=tensor([0, 1, 3, 4, 2]),
tensor([0, 0, 1, 1, 2, 3, 4])),), node_pairs=CSCFormatBase(indptr=tensor([0, 2, 4, 5, 6, 7]),
FusedSampledSubgraphImpl(original_row_node_ids=tensor([0, 1, 3, 2, 4]), indices=tensor([1, 4, 0, 2, 3, 2, 5]),
original_edge_ids=None, ),
original_column_node_ids=tensor([0, 1, 3]), ), SampledSubgraphImpl(original_row_node_ids=tensor([0, 1, 3, 4, 2]),
node_pairs=(tensor([1, 3, 0, 2, 4]), original_edge_ids=None,
tensor([0, 0, 1, 1, 2])), original_column_node_ids=tensor([0, 1, 3, 4]),
node_pairs=CSCFormatBase(indptr=tensor([0, 2, 4, 5, 6]),
indices=tensor([1, 4, 0, 2, 3, 2]),
),
)] )]
""" """
...@@ -237,28 +240,39 @@ class LayerNeighborSampler(NeighborSampler): ...@@ -237,28 +240,39 @@ class LayerNeighborSampler(NeighborSampler):
Examples Examples
------- -------
>>> import dgl.graphbolt as gb >>> import dgl.graphbolt as gb
>>> from dgl import graphbolt as gb >>> import torch
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8]) >>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
>>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5]) >>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
>>> graph = gb.fused_csc_sampling_graph(indptr, indices) >>> graph = gb.fused_csc_sampling_graph(indptr, indices)
>>> data_format = gb.LinkPredictionEdgeFormat.INDEPENDENT
>>> node_pairs = torch.LongTensor([[0, 1], [1, 2]]) >>> node_pairs = torch.LongTensor([[0, 1], [1, 2]])
>>> item_set = gb.ItemSet(node_pairs, names="node_pairs") >>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
>>> item_sampler = gb.ItemSampler( >>> item_sampler = gb.ItemSampler(item_set, batch_size=1,)
... item_set, batch_size=1,) >>> neg_sampler = gb.UniformNegativeSampler(item_sampler, graph, 2)
>>> neg_sampler = gb.UniformNegativeSampler( >>> fanouts = [torch.LongTensor([5]),
... item_sampler, 2, data_format, graph) ... torch.LongTensor([10]),torch.LongTensor([15])]
>>> fanouts = [torch.LongTensor([5]), torch.LongTensor([10]), >>> subgraph_sampler = gb.LayerNeighborSampler(neg_sampler, graph, fanouts)
... torch.LongTensor([15])] >>> next(iter(subgraph_sampler)).sampled_subgraphs
>>> subgraph_sampler = gb.LayerNeighborSampler( [SampledSubgraphImpl(original_row_node_ids=tensor([0, 1, 2, 3, 5, 4]),
... neg_sampler, graph, fanouts) original_edge_ids=None,
>>> for data in subgraph_sampler: original_column_node_ids=tensor([0, 1, 2, 3, 5, 4]),
... print(data.compacted_node_pairs) node_pairs=CSCFormatBase(indptr=tensor([0, 2, 4, 5, 6, 7, 8]),
... print(len(data.sampled_subgraphs)) indices=tensor([1, 2, 0, 3, 4, 5, 4, 3]),
(tensor([0, 0, 0]), tensor([1, 0, 2])) ),
3 ), SampledSubgraphImpl(original_row_node_ids=tensor([0, 1, 2, 3, 5, 4]),
(tensor([0, 0, 0]), tensor([1, 1, 1])) original_edge_ids=None,
3 original_column_node_ids=tensor([0, 1, 2, 3, 5]),
node_pairs=CSCFormatBase(indptr=tensor([0, 2, 4, 5, 6, 7]),
indices=tensor([1, 2, 0, 3, 4, 5, 4]),
),
), SampledSubgraphImpl(original_row_node_ids=tensor([0, 1, 2, 3, 5]),
original_edge_ids=None,
original_column_node_ids=tensor([0, 1, 2]),
node_pairs=CSCFormatBase(indptr=tensor([0, 2, 4, 5]),
indices=tensor([1, 2, 0, 3, 4]),
),
)]
>>> next(iter(subgraph_sampler)).compacted_node_pairs
(tensor([0]), tensor([1]))
""" """
def __init__( def __init__(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment