"tests/vscode:/vscode.git/clone" did not exist on "b28ab30215b908a414ca9ad84d8fcda0aea45ed5"
Unverified Commit 06dc1dc4 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[doc] update docstring for gb.sample_neighbor (#6740)

parent ad1b9269
......@@ -67,19 +67,26 @@ class NeighborSampler(SubgraphSampler):
>>> graph = gb.fused_csc_sampling_graph(indptr, indices)
>>> node_pairs = torch.LongTensor([[0, 1], [1, 2]])
>>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=1,)
>>> neg_sampler = gb.UniformNegativeSampler(
... item_sampler, graph, 2)
>>> subgraph_sampler = gb.NeighborSampler(
... neg_sampler, graph, [5, 10, 15])
>>> for data in subgraph_sampler:
... print(data.compacted_node_pairs)
... print(len(data.sampled_subgraphs))
(tensor([0, 0, 0]), tensor([1, 0, 2]))
3
(tensor([0, 0, 0]), tensor([1, 1, 1]))
3
>>> datapipe = gb.ItemSampler(item_set, batch_size=1)
>>> datapipe = datapipe.sample_uniform_negative(graph, 2)
>>> datapipe = datapipe.sample_neighbor(graph, [5, 10, 15])
>>> next(iter(datapipe)).sampled_subgraphs
[FusedSampledSubgraphImpl(original_row_node_ids=tensor([0, 1, 3, 2, 4, 5]),
original_edge_ids=None,
original_column_node_ids=tensor([0, 1, 3, 2, 4, 5]),
node_pairs=(tensor([1, 3, 0, 2, 4, 5, 2, 5]),
tensor([0, 0, 1, 1, 2, 3, 4, 5])),),
FusedSampledSubgraphImpl(original_row_node_ids=tensor([0, 1, 3, 2, 4, 5]),
original_edge_ids=None,
original_column_node_ids=tensor([0, 1, 3, 2, 4]),
node_pairs=(tensor([1, 3, 0, 2, 4, 5, 2]),
tensor([0, 0, 1, 1, 2, 3, 4])),),
FusedSampledSubgraphImpl(original_row_node_ids=tensor([0, 1, 3, 2, 4]),
original_edge_ids=None,
original_column_node_ids=tensor([0, 1, 3]),
node_pairs=(tensor([1, 3, 0, 2, 4]),
tensor([0, 0, 1, 1, 2])),
)]
"""
def __init__(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment