Unverified Commit 17198e9e authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Graphbolt] Rename num_nodes and num_edges. (#6388)

parent 297e120f
......@@ -87,7 +87,7 @@ class CSCSamplingGraph:
self._metadata = metadata
@property
def num_nodes(self) -> int:
def total_num_nodes(self) -> int:
"""Returns the number of nodes in the graph.
Returns
......@@ -98,7 +98,7 @@ class CSCSamplingGraph:
return self._c_csc_graph.num_nodes()
@property
def num_edges(self) -> int:
def total_num_edges(self) -> int:
"""Returns the number of edges in the graph.
Returns
......@@ -116,7 +116,7 @@ class CSCSamplingGraph:
-------
torch.tensor
The indices pointer in the CSC graph. An integer tensor with
shape `(num_nodes+1,)`.
shape `(total_num_nodes+1,)`.
"""
return self._c_csc_graph.csc_indptr()
......@@ -128,7 +128,7 @@ class CSCSamplingGraph:
-------
torch.tensor
The indices in the CSC graph. An integer tensor with shape
`(num_edges,)`.
`(total_num_edges,)`.
Notes
-------
......@@ -161,7 +161,7 @@ class CSCSamplingGraph:
Returns
-------
torch.Tensor or None
If present, returns a 1D integer tensor of shape (num_edges,)
If present, returns a 1D integer tensor of shape (total_num_edges,)
containing the type of each edge in the graph.
"""
return self._c_csc_graph.type_per_edge()
......@@ -377,7 +377,7 @@ class CSCSamplingGraph:
probs_or_mask = self.edge_attributes[probs_name]
assert probs_or_mask.dim() == 1, "Probs should be 1-D tensor."
assert (
probs_or_mask.size(0) == self.num_edges
probs_or_mask.size(0) == self.total_num_edges
), "Probs should have the same number of elements as the number \
of edges."
assert probs_or_mask.dtype in [
......@@ -566,7 +566,7 @@ class CSCSamplingGraph:
- self.node_type_offset[dst_node_type_id]
)
else:
max_node_id = self.num_nodes
max_node_id = self.total_num_nodes
return self._c_csc_graph.sample_negative_edges_uniform(
node_pairs,
negative_ratio,
......@@ -606,10 +606,10 @@ def from_csc(
----------
csc_indptr : torch.Tensor
Pointer to the start of each row in the `indices`. An integer tensor
with shape `(num_nodes+1,)`.
with shape `(total_num_nodes+1,)`.
indices : torch.Tensor
Column indices of the non-zero elements in the CSC graph. An integer
tensor with shape `(num_edges,)`.
tensor with shape `(total_num_edges,)`.
node_type_offset : Optional[torch.tensor], optional
Offset of node types in the graph, by default None.
type_per_edge : Optional[torch.tensor], optional
......@@ -637,7 +637,7 @@ def from_csc(
>>> print(graph)
CSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7]),
indices=tensor([1, 3, 0, 1, 2, 0, 3]),
num_nodes=3, num_edges=7)
total_num_nodes=3, total_num_edges=7)
"""
if metadata and metadata.node_type_to_id and node_type_offset is not None:
assert len(metadata.node_type_to_id) + 1 == node_type_offset.size(
......@@ -683,7 +683,10 @@ def _csc_sampling_graph_str(graph: CSCSamplingGraph) -> str:
"""
csc_indptr_str = str(graph.csc_indptr)
indices_str = str(graph.indices)
meta_str = f"num_nodes={graph.num_nodes}, num_edges={graph.num_edges}"
meta_str = (
f"total_num_nodes={graph.total_num_nodes}, total_num_edges="
f"{graph.total_num_edges}"
)
prefix = f"{type(graph).__name__}("
def _add_indent(_str, indent):
......
......@@ -961,8 +961,8 @@ def test_OnDiskDataset_Graph_homogeneous():
dataset = gb.OnDiskDataset(test_dir).load()
graph2 = dataset.graph
assert graph.num_nodes == graph2.num_nodes
assert graph.num_edges == graph2.num_edges
assert graph.total_num_nodes == graph2.total_num_nodes
assert graph.total_num_edges == graph2.total_num_edges
assert torch.equal(graph.csc_indptr, graph2.csc_indptr)
assert torch.equal(graph.indices, graph2.indices)
......@@ -1004,8 +1004,8 @@ def test_OnDiskDataset_Graph_heterogeneous():
dataset = gb.OnDiskDataset(test_dir).load()
graph2 = dataset.graph
assert graph.num_nodes == graph2.num_nodes
assert graph.num_edges == graph2.num_edges
assert graph.total_num_nodes == graph2.total_num_nodes
assert graph.total_num_edges == graph2.total_num_edges
assert torch.equal(graph.csc_indptr, graph2.csc_indptr)
assert torch.equal(graph.indices, graph2.indices)
......@@ -1076,8 +1076,8 @@ def test_OnDiskDataset_preprocess_homogeneous():
csc_sampling_graph = gb.csc_sampling_graph.load_csc_sampling_graph(
os.path.join(test_dir, processed_dataset["graph_topology"]["path"])
)
assert csc_sampling_graph.num_nodes == num_nodes
assert csc_sampling_graph.num_edges == num_edges
assert csc_sampling_graph.total_num_nodes == num_nodes
assert csc_sampling_graph.total_num_edges == num_edges
num_samples = 100
fanout = 1
......
......@@ -7,8 +7,8 @@ from torchdata.datapipes.iter import Mapper
def test_FeatureFetcher_invoke():
# Prepare graph and required datapipes.
graph = gb_test_utils.rand_csc_graph(20, 0.15)
a = torch.randint(0, 10, (graph.num_nodes,))
b = torch.randint(0, 10, (graph.num_edges,))
a = torch.randint(0, 10, (graph.total_num_nodes,))
b = torch.randint(0, 10, (graph.total_num_edges,))
features = {}
keys = [("node", None, "a"), ("edge", None, "b")]
......@@ -35,8 +35,8 @@ def test_FeatureFetcher_invoke():
def test_FeatureFetcher_homo():
graph = gb_test_utils.rand_csc_graph(20, 0.15)
a = torch.randint(0, 10, (graph.num_nodes,))
b = torch.randint(0, 10, (graph.num_edges,))
a = torch.randint(0, 10, (graph.total_num_nodes,))
b = torch.randint(0, 10, (graph.total_num_edges,))
features = {}
keys = [("node", None, "a"), ("edge", None, "b")]
......@@ -56,8 +56,8 @@ def test_FeatureFetcher_homo():
def test_FeatureFetcher_with_edges_homo():
graph = gb_test_utils.rand_csc_graph(20, 0.15)
a = torch.randint(0, 10, (graph.num_nodes,))
b = torch.randint(0, 10, (graph.num_edges,))
a = torch.randint(0, 10, (graph.total_num_nodes,))
b = torch.randint(0, 10, (graph.total_num_edges,))
def add_node_and_edge_ids(seeds):
subgraphs = []
......@@ -65,7 +65,9 @@ def test_FeatureFetcher_with_edges_homo():
subgraphs.append(
gb.SampledSubgraphImpl(
node_pairs=(torch.tensor([]), torch.tensor([])),
original_edge_ids=torch.randint(0, graph.num_edges, (10,)),
original_edge_ids=torch.randint(
0, graph.total_num_edges, (10,)
),
)
)
data = gb.MiniBatch(input_nodes=seeds, sampled_subgraphs=subgraphs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment