"tutorials/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "c55ab2d147c0d424e9abcb7521d67dc5a2e793fb"
Unverified Commit 5ce42668 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Modify docstring of `SampledSubgraph`. (#6821)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 6cb8c662
...@@ -22,13 +22,11 @@ class SampledSubgraph: ...@@ -22,13 +22,11 @@ class SampledSubgraph:
@property @property
def node_pairs( def node_pairs(
self, self,
) -> Union[ ) -> Union[CSCFormatBase, Dict[str, CSCFormatBase],]:
Tuple[torch.Tensor, torch.Tensor], """Returns the node pairs representing edges in csc format.
Dict[str, Tuple[torch.Tensor, torch.Tensor]], - If `node_pairs` is a CSCFormatBase: It should be in the csc format.
]: `indptr` stores the index in the data array where each column
"""Returns the node pairs representing source-destination edges. starts. `indices` stores the row indices of the non-zero elements.
- If `node_pairs` is a tuple: It should be in the format ('u', 'v')
representing source and destination pairs.
- If `node_pairs` is a dictionary: The keys should be edge type and - If `node_pairs` is a dictionary: The keys should be edge type and
the values should be corresponding node pairs. The ids inside the values should be corresponding node pairs. The ids inside
is heterogeneous ids.""" is heterogeneous ids."""
...@@ -119,27 +117,33 @@ class SampledSubgraph: ...@@ -119,27 +117,33 @@ class SampledSubgraph:
Examples Examples
-------- --------
>>> node_pairs = {"A:relation:B": (torch.tensor([0, 1, 2]), >>> import dgl.graphbolt as gb
... torch.tensor([0, 1, 2]))} >>> import torch
>>> original_column_node_ids = {'B': torch.tensor([10, 11, 12])} >>> node_pairs = {"A:relation:B": gb.CSCFormatBase(
>>> original_row_node_ids = {'A': torch.tensor([13, 14, 15])} ... indptr=torch.tensor([0, 1, 2, 3]),
... indices=torch.tensor([0, 1, 2]))}
>>> original_column_node_ids = {"B": torch.tensor([10, 11, 12])}
>>> original_row_node_ids = {"A": torch.tensor([13, 14, 15])}
>>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])} >>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>> subgraph = gb.FusedSampledSubgraphImpl( >>> subgraph = gb.SampledSubgraphImpl(
... node_pairs=node_pairs, ... node_pairs=node_pairs,
... original_column_node_ids=original_column_node_ids, ... original_column_node_ids=original_column_node_ids,
... original_row_node_ids=original_row_node_ids, ... original_row_node_ids=original_row_node_ids,
... original_edge_ids=original_edge_ids ... original_edge_ids=original_edge_ids
... ) ... )
>>> edges_to_exclude = (torch.tensor([14, 15]), torch.tensor([11, 12])) >>> edges_to_exclude = {"A:relation:B": (torch.tensor([14, 15]),
... torch.tensor([11, 12]))}
>>> result = subgraph.exclude_edges(edges_to_exclude) >>> result = subgraph.exclude_edges(edges_to_exclude)
>>> print(result.node_pairs) >>> print(result.node_pairs)
{"A:relation:B": (tensor([0]), tensor([0]))} {'A:relation:B': CSCFormatBase(indptr=tensor([0, 1, 1, 1]),
indices=tensor([0]),
)}
>>> print(result.original_column_node_ids) >>> print(result.original_column_node_ids)
{'B': tensor([10, 11, 12])} {'B': tensor([10, 11, 12])}
>>> print(result.original_row_node_ids) >>> print(result.original_row_node_ids)
{'A': tensor([13, 14, 15])} {'A': tensor([13, 14, 15])}
>>> print(result.original_edge_ids) >>> print(result.original_edge_ids)
{"A:relation:B": tensor([19])} {'A:relation:B': tensor([19])}
""" """
# TODO: Add support for value > in32, then remove this line. # TODO: Add support for value > in32, then remove this line.
assert ( assert (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment