"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "8ef0d9deff84bf70a6a07ddbaa5cedac345f0f67"
Unverified Commit 884a378c authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Update docstring in `FusedCSCSamplingGraph`. (#6837)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 749ac593
...@@ -369,11 +369,15 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -369,11 +369,15 @@ class FusedCSCSamplingGraph(SamplingGraph):
>>> nodes = {"N0":torch.LongTensor([1]), "N1":torch.LongTensor([1, 2])} >>> nodes = {"N0":torch.LongTensor([1]), "N1":torch.LongTensor([1, 2])}
>>> in_subgraph = graph.in_subgraph(nodes) >>> in_subgraph = graph.in_subgraph(nodes)
>>> print(in_subgraph.sampled_csc) >>> print(in_subgraph.sampled_csc)
defaultdict(<class 'list'>, { {'N0:R0:N0': CSCFormatBase(indptr=tensor([0, 0]),
'N0:R0:N0': (tensor([]), tensor([])), indices=tensor([], dtype=torch.int64),
'N0:R1:N1': (tensor([1, 0]), tensor([1, 2])), ), 'N0:R1:N1': CSCFormatBase(indptr=tensor([0, 1, 2]),
'N1:R2:N0': (tensor([0, 1]), tensor([1, 1])), indices=tensor([1, 0]),
'N1:R3:N1': (tensor([0, 1, 2]), tensor([1, 2, 2]))} ), 'N1:R2:N0': CSCFormatBase(indptr=tensor([0, 2]),
indices=tensor([0, 1]),
), 'N1:R3:N1': CSCFormatBase(indptr=tensor([0, 1, 3]),
indices=tensor([0, 1, 2]),
)}
""" """
if isinstance(nodes, dict): if isinstance(nodes, dict):
nodes = self._convert_to_homogeneous_nodes(nodes) nodes = self._convert_to_homogeneous_nodes(nodes)
...@@ -598,7 +602,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -598,7 +602,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
Returns Returns
------- -------
FusedSampledSubgraphImpl Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]
The sampled subgraph. The sampled subgraph.
Examples Examples
...@@ -620,8 +624,11 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -620,8 +624,11 @@ class FusedCSCSamplingGraph(SamplingGraph):
>>> fanouts = torch.tensor([1, 1]) >>> fanouts = torch.tensor([1, 1])
>>> subgraph = graph.sample_neighbors(nodes, fanouts) >>> subgraph = graph.sample_neighbors(nodes, fanouts)
>>> print(subgraph.sampled_csc) >>> print(subgraph.sampled_csc)
defaultdict(<class 'list'>, {'n1:e1:n2': (tensor([0]), {'n1:e1:n2': CSCFormatBase(indptr=tensor([0, 1]),
tensor([0])), 'n2:e2:n1': (tensor([2]), tensor([0]))}) indices=tensor([0]),
), 'n2:e2:n1': CSCFormatBase(indptr=tensor([0, 1]),
indices=tensor([2]),
)}
""" """
if isinstance(nodes, dict): if isinstance(nodes, dict):
nodes = self._convert_to_homogeneous_nodes(nodes) nodes = self._convert_to_homogeneous_nodes(nodes)
...@@ -812,8 +819,11 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -812,8 +819,11 @@ class FusedCSCSamplingGraph(SamplingGraph):
>>> fanouts = torch.tensor([1, 1]) >>> fanouts = torch.tensor([1, 1])
>>> subgraph = graph.sample_layer_neighbors(nodes, fanouts) >>> subgraph = graph.sample_layer_neighbors(nodes, fanouts)
>>> print(subgraph.sampled_csc) >>> print(subgraph.sampled_csc)
defaultdict(<class 'list'>, {'n1:e1:n2': (tensor([1]), {'n1:e1:n2': CSCFormatBase(indptr=tensor([0, 1]),
tensor([0])), 'n2:e2:n1': (tensor([2]), tensor([0]))}) indices=tensor([0]),
), 'n2:e2:n1': CSCFormatBase(indptr=tensor([0, 1]),
indices=tensor([2]),
)}
""" """
if isinstance(nodes, dict): if isinstance(nodes, dict):
nodes = self._convert_to_homogeneous_nodes(nodes) nodes = self._convert_to_homogeneous_nodes(nodes)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment