Unverified Commit d08075d4 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt] Support return reverse edge ids in sampling (#6347)

parent 5a7e156f
......@@ -7,12 +7,14 @@ from ..utils import recursive_apply
__all__ = [
"CANONICAL_ETYPE_DELIMITER",
"ORIGINAL_EDGE_ID",
"etype_str_to_tuple",
"etype_tuple_to_str",
"CopyTo",
]
CANONICAL_ETYPE_DELIMITER = ":"
ORIGINAL_EDGE_ID = "_ORIGINAL_EDGE_ID"
def etype_tuple_to_str(c_etype):
......
......@@ -11,7 +11,7 @@ import torch
from ...base import ETYPE
from ...convert import to_homogeneous
from ...heterograph import DGLGraph
from ..base import etype_str_to_tuple, etype_tuple_to_str
from ..base import etype_str_to_tuple, etype_tuple_to_str, ORIGINAL_EDGE_ID
from .sampled_subgraph_impl import SampledSubgraphImpl
......@@ -230,6 +230,15 @@ class CSCSamplingGraph:
)
row = C_sampled_subgraph.indices
type_per_edge = C_sampled_subgraph.type_per_edge
original_edge_ids = C_sampled_subgraph.original_edge_ids
has_original_eids = (
self.edge_attributes is not None
and ORIGINAL_EDGE_ID in self.edge_attributes
)
if has_original_eids:
original_edge_ids = self.edge_attributes[ORIGINAL_EDGE_ID][
original_edge_ids
]
if type_per_edge is None:
# The sampled graph is already a homogeneous graph.
node_pairs = (row, column)
......@@ -237,6 +246,7 @@ class CSCSamplingGraph:
# The sampled graph is a fused homogenized graph, which need to be
# converted to heterogeneous graphs.
node_pairs = defaultdict(list)
original_hetero_edge_ids = {}
for etype, etype_id in self.metadata.edge_type_to_id.items():
src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
src_ntype_id = self.metadata.node_type_to_id[src_ntype]
......@@ -247,7 +257,13 @@ class CSCSamplingGraph:
column[mask] - self.node_type_offset[dst_ntype_id]
)
node_pairs[etype] = (hetero_row, hetero_column)
return SampledSubgraphImpl(node_pairs=node_pairs)
if has_original_eids:
original_hetero_edge_ids[etype] = original_edge_ids[mask]
if has_original_eids:
original_edge_ids = original_hetero_edge_ids
return SampledSubgraphImpl(
node_pairs=node_pairs, original_edge_ids=original_edge_ids
)
def _convert_to_homogeneous_nodes(self, nodes):
homogeneous_nodes = []
......@@ -329,7 +345,7 @@ class CSCSamplingGraph:
nodes = self._convert_to_homogeneous_nodes(nodes)
C_sampled_subgraph = self._sample_neighbors(
nodes, fanouts, replace, False, probs_name
nodes, fanouts, replace, probs_name
)
return self._convert_to_sampled_subgraph(C_sampled_subgraph)
......@@ -377,7 +393,6 @@ class CSCSamplingGraph:
nodes: torch.Tensor,
fanouts: torch.Tensor,
replace: bool = False,
return_eids: bool = False,
probs_name: Optional[str] = None,
) -> torch.ScriptObject:
"""Sample neighboring edges of the given nodes and return the induced
......@@ -408,10 +423,6 @@ class CSCSamplingGraph:
Boolean indicating whether the sample is preformed with or
without replacement. If True, a value can be selected multiple
times. Otherwise, each value can be selected only once.
return_eids: bool
Boolean indicating whether the edge IDs of sampled edges,
represented as a 1D tensor, should be returned. This is
typically used when edge features are required.
probs_name: str, optional
An optional string specifying the name of an edge attribute. This
attribute tensor should contain (unnormalized) probabilities
......@@ -425,8 +436,12 @@ class CSCSamplingGraph:
"""
# Ensure nodes is 1-D tensor.
self._check_sampler_arguments(nodes, fanouts, probs_name)
has_origin_eids = (
self.edge_attributes is not None
and ORIGINAL_EDGE_ID in self.edge_attributes
)
return self._c_csc_graph.sample_neighbors(
nodes, fanouts.tolist(), replace, False, return_eids, probs_name
nodes, fanouts.tolist(), replace, False, has_origin_eids, probs_name
)
def sample_layer_neighbors(
......@@ -489,8 +504,17 @@ class CSCSamplingGraph:
nodes = self._convert_to_homogeneous_nodes(nodes)
self._check_sampler_arguments(nodes, fanouts, probs_name)
has_original_eids = (
self.edge_attributes is not None
and ORIGINAL_EDGE_ID in self.edge_attributes
)
C_sampled_subgraph = self._c_csc_graph.sample_neighbors(
nodes, fanouts.tolist(), replace, True, False, probs_name
nodes,
fanouts.tolist(),
replace,
True,
has_original_eids,
probs_name,
)
return self._convert_to_sampled_subgraph(C_sampled_subgraph)
......
......@@ -721,6 +721,107 @@ def test_sample_neighbors_replace(
assert subgraph.node_pairs["n2:e2:n1"][0].numel() == expected_sampled_num2
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_return_eids_homo(labor):
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
0 1 0 1 0
0 1 0 0 1
1 0 0 0 1
"""
# Initialize data.
num_nodes = 5
num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
assert indptr[-1] == num_edges
assert indptr[-1] == len(indices)
# Add edge id mapping from CSC graph -> original graph.
edge_attributes = {gb.ORIGINAL_EDGE_ID: torch.randperm(num_edges)}
# Construct CSCSamplingGraph.
graph = gb.from_csc(indptr, indices, edge_attributes=edge_attributes)
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
subgraph = graph.sample_neighbors(nodes, fanouts=torch.LongTensor([-1]))
# Verify in subgraph.
expected_reverse_edge_ids = edge_attributes[gb.ORIGINAL_EDGE_ID][
torch.tensor([3, 4, 7, 8, 9, 10, 11])
]
assert torch.equal(expected_reverse_edge_ids, subgraph.original_edge_ids)
assert subgraph.original_column_node_ids is None
assert subgraph.original_row_node_ids is None
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_return_eids_hetero(labor):
"""
Original graph in COO:
"n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
0 0 1 0 1
0 0 1 1 1
1 1 0 0 0
0 1 0 0 0
1 0 0 0 0
"""
# Initialize data.
ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
num_nodes = 5
num_edges = 9
indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
node_type_offset = torch.LongTensor([0, 2, 5])
edge_attributes = {
gb.ORIGINAL_EDGE_ID: torch.cat([torch.randperm(4), torch.randperm(5)])
}
assert indptr[-1] == num_edges
assert indptr[-1] == len(indices)
# Construct CSCSamplingGraph.
graph = gb.from_csc(
indptr,
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
edge_attributes=edge_attributes,
metadata=metadata,
)
# Sample on both node types.
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts)
# Verify in subgraph.
expected_reverse_edge_ids = {
"n2:e2:n1": edge_attributes[gb.ORIGINAL_EDGE_ID][torch.tensor([0, 1])],
"n1:e1:n2": edge_attributes[gb.ORIGINAL_EDGE_ID][torch.tensor([4, 5])],
}
assert subgraph.original_column_node_ids is None
assert subgraph.original_row_node_ids is None
for etype in etypes.keys():
assert torch.equal(
subgraph.original_edge_ids[etype], expected_reverse_edge_ids[etype]
)
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment