"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "3feb5021b49f33e5112eeb1458143d5c8a55f17f"
Unverified Commit 47c6fb1f authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt]Fix docstring example (#6328)

parent e3bf1c0e
...@@ -61,18 +61,15 @@ class NeighborSampler(SubgraphSampler): ...@@ -61,18 +61,15 @@ class NeighborSampler(SubgraphSampler):
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8]) >>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
>>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5]) >>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
>>> graph = gb.from_csc(indptr, indices) >>> graph = gb.from_csc(indptr, indices)
>>> data_format = gb.LinkPredictionEdgeFormat.INDEPENDENT
>>> node_pairs = torch.LongTensor([[0, 1], [1, 2]]) >>> node_pairs = torch.LongTensor([[0, 1], [1, 2]])
>>> item_set = gb.ItemSet(node_pairs, names="node_pairs") >>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
>>> item_sampler = gb.ItemSampler( >>> item_sampler = gb.ItemSampler(
...item_set, batch_size=1, ...item_set, batch_size=1,
...) ...)
>>> neg_sampler = gb.UniformNegativeSampler( >>> neg_sampler = gb.UniformNegativeSampler(
...item_sampler, 2, data_format, graph) ...item_sampler, graph, 2)
>>> fanouts = [torch.LongTensor([5]), torch.LongTensor([10]),
...torch.LongTensor([15])]
>>> subgraph_sampler = gb.NeighborSampler( >>> subgraph_sampler = gb.NeighborSampler(
...neg_sampler, graph, fanouts) ...neg_sampler, graph, [5, 10, 15])
>>> for data in subgraph_sampler: >>> for data in subgraph_sampler:
... print(data.compacted_node_pairs) ... print(data.compacted_node_pairs)
... print(len(data.sampled_subgraphs)) ... print(len(data.sampled_subgraphs))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment