You need to sign in or sign up before continuing.
Unverified Commit 5ce42668 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Modify docstring of `SampledSubgraph`. (#6821)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 6cb8c662
......@@ -22,13 +22,11 @@ class SampledSubgraph:
@property
def node_pairs(
self,
) -> Union[
Tuple[torch.Tensor, torch.Tensor],
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
]:
"""Returns the node pairs representing source-destination edges.
- If `node_pairs` is a tuple: It should be in the format ('u', 'v')
representing source and destination pairs.
) -> Union[CSCFormatBase, Dict[str, CSCFormatBase],]:
"""Returns the node pairs representing edges in csc format.
- If `node_pairs` is a CSCFormatBase: It should be in the csc format.
`indptr` stores the index in the data array where each column
starts. `indices` stores the row indices of the non-zero elements.
- If `node_pairs` is a dictionary: The keys should be edge type and
the values should be corresponding node pairs. The ids inside
is heterogeneous ids."""
......@@ -119,27 +117,33 @@ class SampledSubgraph:
Examples
--------
>>> node_pairs = {"A:relation:B": (torch.tensor([0, 1, 2]),
... torch.tensor([0, 1, 2]))}
>>> original_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>> original_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>> import dgl.graphbolt as gb
>>> import torch
>>> node_pairs = {"A:relation:B": gb.CSCFormatBase(
... indptr=torch.tensor([0, 1, 2, 3]),
... indices=torch.tensor([0, 1, 2]))}
>>> original_column_node_ids = {"B": torch.tensor([10, 11, 12])}
>>> original_row_node_ids = {"A": torch.tensor([13, 14, 15])}
>>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>> subgraph = gb.FusedSampledSubgraphImpl(
>>> subgraph = gb.SampledSubgraphImpl(
... node_pairs=node_pairs,
... original_column_node_ids=original_column_node_ids,
... original_row_node_ids=original_row_node_ids,
... original_edge_ids=original_edge_ids
... )
>>> edges_to_exclude = (torch.tensor([14, 15]), torch.tensor([11, 12]))
>>> edges_to_exclude = {"A:relation:B": (torch.tensor([14, 15]),
... torch.tensor([11, 12]))}
>>> result = subgraph.exclude_edges(edges_to_exclude)
>>> print(result.node_pairs)
{"A:relation:B": (tensor([0]), tensor([0]))}
{'A:relation:B': CSCFormatBase(indptr=tensor([0, 1, 1, 1]),
indices=tensor([0]),
)}
>>> print(result.original_column_node_ids)
{'B': tensor([10, 11, 12])}
>>> print(result.original_row_node_ids)
{'A': tensor([13, 14, 15])}
>>> print(result.original_edge_ids)
{"A:relation:B": tensor([19])}
{'A:relation:B': tensor([19])}
"""
# TODO: Add support for value > in32, then remove this line.
assert (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment