Unverified Commit d08075d4 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt] Support return reverse edge ids in sampling (#6347)

parent 5a7e156f
...@@ -7,12 +7,14 @@ from ..utils import recursive_apply ...@@ -7,12 +7,14 @@ from ..utils import recursive_apply
__all__ = [ __all__ = [
"CANONICAL_ETYPE_DELIMITER", "CANONICAL_ETYPE_DELIMITER",
"ORIGINAL_EDGE_ID",
"etype_str_to_tuple", "etype_str_to_tuple",
"etype_tuple_to_str", "etype_tuple_to_str",
"CopyTo", "CopyTo",
] ]
CANONICAL_ETYPE_DELIMITER = ":" CANONICAL_ETYPE_DELIMITER = ":"
ORIGINAL_EDGE_ID = "_ORIGINAL_EDGE_ID"
def etype_tuple_to_str(c_etype): def etype_tuple_to_str(c_etype):
......
...@@ -11,7 +11,7 @@ import torch ...@@ -11,7 +11,7 @@ import torch
from ...base import ETYPE from ...base import ETYPE
from ...convert import to_homogeneous from ...convert import to_homogeneous
from ...heterograph import DGLGraph from ...heterograph import DGLGraph
from ..base import etype_str_to_tuple, etype_tuple_to_str from ..base import etype_str_to_tuple, etype_tuple_to_str, ORIGINAL_EDGE_ID
from .sampled_subgraph_impl import SampledSubgraphImpl from .sampled_subgraph_impl import SampledSubgraphImpl
...@@ -230,6 +230,15 @@ class CSCSamplingGraph: ...@@ -230,6 +230,15 @@ class CSCSamplingGraph:
) )
row = C_sampled_subgraph.indices row = C_sampled_subgraph.indices
type_per_edge = C_sampled_subgraph.type_per_edge type_per_edge = C_sampled_subgraph.type_per_edge
original_edge_ids = C_sampled_subgraph.original_edge_ids
has_original_eids = (
self.edge_attributes is not None
and ORIGINAL_EDGE_ID in self.edge_attributes
)
if has_original_eids:
original_edge_ids = self.edge_attributes[ORIGINAL_EDGE_ID][
original_edge_ids
]
if type_per_edge is None: if type_per_edge is None:
# The sampled graph is already a homogeneous graph. # The sampled graph is already a homogeneous graph.
node_pairs = (row, column) node_pairs = (row, column)
...@@ -237,6 +246,7 @@ class CSCSamplingGraph: ...@@ -237,6 +246,7 @@ class CSCSamplingGraph:
# The sampled graph is a fused homogenized graph, which need to be # The sampled graph is a fused homogenized graph, which need to be
# converted to heterogeneous graphs. # converted to heterogeneous graphs.
node_pairs = defaultdict(list) node_pairs = defaultdict(list)
original_hetero_edge_ids = {}
for etype, etype_id in self.metadata.edge_type_to_id.items(): for etype, etype_id in self.metadata.edge_type_to_id.items():
src_ntype, _, dst_ntype = etype_str_to_tuple(etype) src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
src_ntype_id = self.metadata.node_type_to_id[src_ntype] src_ntype_id = self.metadata.node_type_to_id[src_ntype]
...@@ -247,7 +257,13 @@ class CSCSamplingGraph: ...@@ -247,7 +257,13 @@ class CSCSamplingGraph:
column[mask] - self.node_type_offset[dst_ntype_id] column[mask] - self.node_type_offset[dst_ntype_id]
) )
node_pairs[etype] = (hetero_row, hetero_column) node_pairs[etype] = (hetero_row, hetero_column)
return SampledSubgraphImpl(node_pairs=node_pairs) if has_original_eids:
original_hetero_edge_ids[etype] = original_edge_ids[mask]
if has_original_eids:
original_edge_ids = original_hetero_edge_ids
return SampledSubgraphImpl(
node_pairs=node_pairs, original_edge_ids=original_edge_ids
)
def _convert_to_homogeneous_nodes(self, nodes): def _convert_to_homogeneous_nodes(self, nodes):
homogeneous_nodes = [] homogeneous_nodes = []
...@@ -329,7 +345,7 @@ class CSCSamplingGraph: ...@@ -329,7 +345,7 @@ class CSCSamplingGraph:
nodes = self._convert_to_homogeneous_nodes(nodes) nodes = self._convert_to_homogeneous_nodes(nodes)
C_sampled_subgraph = self._sample_neighbors( C_sampled_subgraph = self._sample_neighbors(
nodes, fanouts, replace, False, probs_name nodes, fanouts, replace, probs_name
) )
return self._convert_to_sampled_subgraph(C_sampled_subgraph) return self._convert_to_sampled_subgraph(C_sampled_subgraph)
...@@ -377,7 +393,6 @@ class CSCSamplingGraph: ...@@ -377,7 +393,6 @@ class CSCSamplingGraph:
nodes: torch.Tensor, nodes: torch.Tensor,
fanouts: torch.Tensor, fanouts: torch.Tensor,
replace: bool = False, replace: bool = False,
return_eids: bool = False,
probs_name: Optional[str] = None, probs_name: Optional[str] = None,
) -> torch.ScriptObject: ) -> torch.ScriptObject:
"""Sample neighboring edges of the given nodes and return the induced """Sample neighboring edges of the given nodes and return the induced
...@@ -408,10 +423,6 @@ class CSCSamplingGraph: ...@@ -408,10 +423,6 @@ class CSCSamplingGraph:
Boolean indicating whether the sample is preformed with or Boolean indicating whether the sample is preformed with or
without replacement. If True, a value can be selected multiple without replacement. If True, a value can be selected multiple
times. Otherwise, each value can be selected only once. times. Otherwise, each value can be selected only once.
return_eids: bool
Boolean indicating whether the edge IDs of sampled edges,
represented as a 1D tensor, should be returned. This is
typically used when edge features are required.
probs_name: str, optional probs_name: str, optional
An optional string specifying the name of an edge attribute. This An optional string specifying the name of an edge attribute. This
attribute tensor should contain (unnormalized) probabilities attribute tensor should contain (unnormalized) probabilities
...@@ -425,8 +436,12 @@ class CSCSamplingGraph: ...@@ -425,8 +436,12 @@ class CSCSamplingGraph:
""" """
# Ensure nodes is 1-D tensor. # Ensure nodes is 1-D tensor.
self._check_sampler_arguments(nodes, fanouts, probs_name) self._check_sampler_arguments(nodes, fanouts, probs_name)
has_origin_eids = (
self.edge_attributes is not None
and ORIGINAL_EDGE_ID in self.edge_attributes
)
return self._c_csc_graph.sample_neighbors( return self._c_csc_graph.sample_neighbors(
nodes, fanouts.tolist(), replace, False, return_eids, probs_name nodes, fanouts.tolist(), replace, False, has_origin_eids, probs_name
) )
def sample_layer_neighbors( def sample_layer_neighbors(
...@@ -489,8 +504,17 @@ class CSCSamplingGraph: ...@@ -489,8 +504,17 @@ class CSCSamplingGraph:
nodes = self._convert_to_homogeneous_nodes(nodes) nodes = self._convert_to_homogeneous_nodes(nodes)
self._check_sampler_arguments(nodes, fanouts, probs_name) self._check_sampler_arguments(nodes, fanouts, probs_name)
has_original_eids = (
self.edge_attributes is not None
and ORIGINAL_EDGE_ID in self.edge_attributes
)
C_sampled_subgraph = self._c_csc_graph.sample_neighbors( C_sampled_subgraph = self._c_csc_graph.sample_neighbors(
nodes, fanouts.tolist(), replace, True, False, probs_name nodes,
fanouts.tolist(),
replace,
True,
has_original_eids,
probs_name,
) )
return self._convert_to_sampled_subgraph(C_sampled_subgraph) return self._convert_to_sampled_subgraph(C_sampled_subgraph)
......
...@@ -721,6 +721,107 @@ def test_sample_neighbors_replace( ...@@ -721,6 +721,107 @@ def test_sample_neighbors_replace(
assert subgraph.node_pairs["n2:e2:n1"][0].numel() == expected_sampled_num2 assert subgraph.node_pairs["n2:e2:n1"][0].numel() == expected_sampled_num2
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_return_eids_homo(labor):
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
0 1 0 1 0
0 1 0 0 1
1 0 0 0 1
"""
# Initialize data.
num_nodes = 5
num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
assert indptr[-1] == num_edges
assert indptr[-1] == len(indices)
# Add edge id mapping from CSC graph -> original graph.
edge_attributes = {gb.ORIGINAL_EDGE_ID: torch.randperm(num_edges)}
# Construct CSCSamplingGraph.
graph = gb.from_csc(indptr, indices, edge_attributes=edge_attributes)
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
subgraph = graph.sample_neighbors(nodes, fanouts=torch.LongTensor([-1]))
# Verify in subgraph.
expected_reverse_edge_ids = edge_attributes[gb.ORIGINAL_EDGE_ID][
torch.tensor([3, 4, 7, 8, 9, 10, 11])
]
assert torch.equal(expected_reverse_edge_ids, subgraph.original_edge_ids)
assert subgraph.original_column_node_ids is None
assert subgraph.original_row_node_ids is None
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_return_eids_hetero(labor):
"""
Original graph in COO:
"n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
0 0 1 0 1
0 0 1 1 1
1 1 0 0 0
0 1 0 0 0
1 0 0 0 0
"""
# Initialize data.
ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
num_nodes = 5
num_edges = 9
indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
node_type_offset = torch.LongTensor([0, 2, 5])
edge_attributes = {
gb.ORIGINAL_EDGE_ID: torch.cat([torch.randperm(4), torch.randperm(5)])
}
assert indptr[-1] == num_edges
assert indptr[-1] == len(indices)
# Construct CSCSamplingGraph.
graph = gb.from_csc(
indptr,
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
edge_attributes=edge_attributes,
metadata=metadata,
)
# Sample on both node types.
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts)
# Verify in subgraph.
expected_reverse_edge_ids = {
"n2:e2:n1": edge_attributes[gb.ORIGINAL_EDGE_ID][torch.tensor([0, 1])],
"n1:e1:n2": edge_attributes[gb.ORIGINAL_EDGE_ID][torch.tensor([4, 5])],
}
assert subgraph.original_column_node_ids is None
assert subgraph.original_row_node_ids is None
for etype in etypes.keys():
assert torch.equal(
subgraph.original_edge_ids[etype], expected_reverse_edge_ids[etype]
)
@unittest.skipIf( @unittest.skipIf(
F._default_context_str == "gpu", F._default_context_str == "gpu",
reason="Graph is CPU only at present.", reason="Graph is CPU only at present.",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment