Unverified Commit d873acc2 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] add more APIs for SamplingGraph (#6719)

parent e6f78c10
......@@ -170,6 +170,55 @@ class FusedCSCSamplingGraph(SamplingGraph):
return num_nodes_per_type
@property
def num_edges(self) -> Union[int, Dict[str, int]]:
    """The number of edges in the graph.

    - If the graph is homogeneous, returns an integer.
    - If the graph is heterogeneous, returns a dictionary.

    Returns
    -------
    Union[int, Dict[str, int]]
        The number of edges. An integer is the total number of edges of a
        homogeneous graph; a dict maps every edge type listed in
        ``edge_type_to_id`` to its number of edges in a heterogeneous
        graph. Edge types that no edge uses are reported with a count of 0.

    Examples
    --------
    >>> import dgl.graphbolt as gb, torch
    >>> total_num_nodes = 5
    >>> total_num_edges = 12
    >>> ntypes = {"N0": 0, "N1": 1}
    >>> etypes = {"N0:R0:N0": 0, "N0:R1:N1": 1,
    ...     "N1:R2:N0": 2, "N1:R3:N1": 3}
    >>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
    >>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
    >>> node_type_offset = torch.LongTensor([0, 2, 5])
    >>> type_per_edge = torch.LongTensor(
    ...     [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
    >>> metadata = gb.GraphMetadata(ntypes, etypes)
    >>> graph = gb.from_fused_csc(indptr, indices, node_type_offset,
    ...     type_per_edge, None, metadata)
    >>> print(graph.num_edges)
    {'N0:R0:N0': 2, 'N0:R1:N1': 4, 'N1:R2:N0': 3, 'N1:R3:N1': 3}
    """
    type_per_edge = self.type_per_edge

    # Homogeneous: without per-edge types or an edge type mapping there is
    # nothing to group by, so report the total count.
    if type_per_edge is None or self.edge_type_to_id is None:
        return self._c_csc_graph.num_edges()

    # Heterogeneous: count the occurrences of every edge type ID. Convert
    # the counts to a Python list once so that each edge type below is a
    # plain list lookup instead of a separate ``.item()`` call.
    counts = torch.bincount(type_per_edge).tolist()

    # ``bincount`` only extends to the largest ID actually present, so edge
    # types with a larger ID have no entry and are reported as 0.
    return {
        etype: counts[etype_id] if etype_id < len(counts) else 0
        for etype, etype_id in self.edge_type_to_id.items()
    }
@property
def csc_indptr(self) -> torch.tensor:
"""Returns the indices pointer in the CSC graph.
......
......@@ -2,6 +2,8 @@
from typing import Dict, Union
import torch
__all__ = ["SamplingGraph"]
......@@ -12,6 +14,16 @@ class SamplingGraph:
def __init__(self):
    """Construct the base object; no attributes are set here."""
def __repr__(self) -> str:
    """Describe the graph as a string.

    The base class provides no implementation.

    Returns
    -------
    str
        A textual description of the graph.

    Raises
    ------
    NotImplementedError
        Always, when called on the base class.
    """
    raise NotImplementedError
@property
def num_nodes(self) -> Union[int, Dict[str, int]]:
"""The number of nodes in the graph.
......@@ -26,3 +38,49 @@ class SamplingGraph:
heterogenous graph.
"""
raise NotImplementedError
@property
def num_edges(self) -> Union[int, Dict[str, int]]:
    """Edge count of the graph.

    The shape of the result depends on the kind of graph:

    - homogeneous graph: a single integer;
    - heterogeneous graph: a dictionary.

    The base class provides no implementation.

    Returns
    -------
    Union[int, Dict[str, int]]
        An integer holding the total number of edges of a homogeneous
        graph, or a dict holding the number of edges per edge type of a
        heterogeneous graph.

    Raises
    ------
    NotImplementedError
        Always, when accessed on the base class.
    """
    raise NotImplementedError
def copy_to_shared_memory(self, shared_memory_name: str) -> "SamplingGraph":
    """Place a copy of the graph in shared memory.

    The base class provides no implementation.

    Parameters
    ----------
    shared_memory_name : str
        The name identifying the shared memory.

    Returns
    -------
    SamplingGraph
        The SamplingGraph object copied onto shared memory.

    Raises
    ------
    NotImplementedError
        Always, when called on the base class.
    """
    raise NotImplementedError
# pylint: disable=invalid-name
def to(self, device: torch.device) -> "SamplingGraph":
    """Copy the graph onto the given device.

    The base class provides no implementation.

    Parameters
    ----------
    device : torch.device
        The device the graph should be copied to.

    Returns
    -------
    SamplingGraph
        The graph residing on the given device.

    Raises
    ------
    NotImplementedError
        Always, when called on the base class.
    """
    raise NotImplementedError
......@@ -187,7 +187,7 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes):
"total_num_nodes, total_num_edges",
[(1, 1), (100, 1), (10, 50), (1000, 50000)],
)
def test_num_nodes_homo(total_num_nodes, total_num_edges):
def test_num_nodes_edges_homo(total_num_nodes, total_num_edges):
csc_indptr, indices = gbt.random_homo_graph(
total_num_nodes, total_num_edges
)
......@@ -200,6 +200,7 @@ def test_num_nodes_homo(total_num_nodes, total_num_edges):
)
assert graph.num_nodes == total_num_nodes
assert graph.num_edges == total_num_edges
@unittest.skipIf(
......@@ -233,6 +234,7 @@ def test_num_nodes_hetero():
"N0:R1:N1": 1,
"N1:R2:N0": 2,
"N1:R3:N1": 3,
"N1:R4:N0": 4,
}
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
......@@ -254,9 +256,16 @@ def test_num_nodes_hetero():
"N0": 2,
"N1": 3,
}
assert graph.num_nodes["N0"] == 2
assert graph.num_nodes["N1"] == 3
assert "N2" not in graph.num_nodes
assert sum(graph.num_nodes.values()) == total_num_nodes
# Verify edges number per edge types.
assert graph.num_edges == {
"N0:R0:N0": 2,
"N0:R1:N1": 4,
"N1:R2:N0": 3,
"N1:R3:N1": 3,
"N1:R4:N0": 0,
}
assert sum(graph.num_edges.values()) == total_num_edges
@unittest.skipIf(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment