Unverified Commit 1785acff authored by yxy235, committed by GitHub
Browse files

[GraphBolt] Refactor csc format sampled subgraph. (#6553)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent fdeda8a8
"""Base types and utilities for Graph Bolt."""
from dataclasses import dataclass
import torch
from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe
......@@ -13,6 +15,7 @@ __all__ = [
"etype_tuple_to_str",
"CopyTo",
"isin",
"CSCFormatBase",
]
CANONICAL_ETYPE_DELIMITER = ":"
......@@ -111,3 +114,21 @@ class CopyTo(IterDataPipe):
for data in self.datapipe:
data = recursive_apply(data, apply_to, self.device)
yield data
@dataclass
class CSCFormatBase:
    r"""Basic class representing data in Compressed Sparse Column (CSC) format.

    Both fields default to ``None``; no validation is performed here.

    Examples
    --------
    >>> indptr = torch.tensor([0, 1, 3])
    >>> indices = torch.tensor([1, 4, 2])
    >>> csc_format_base = CSCFormatBase(indptr=indptr, indices=indices)
    >>> print(csc_format_base.indptr)
    tensor([0, 1, 3])
    >>> print(csc_format_base.indices)
    tensor([1, 4, 2])
    """

    # Column pointer tensor of the CSC structure.
    indptr: torch.Tensor = None
    # Index tensor of the CSC structure.
    indices: torch.Tensor = None
......@@ -15,7 +15,11 @@ from ...convert import to_homogeneous
from ...heterograph import DGLGraph
from ..base import etype_str_to_tuple, etype_tuple_to_str, ORIGINAL_EDGE_ID
from ..sampling_graph import SamplingGraph
from .sampled_subgraph_impl import FusedSampledSubgraphImpl
from .sampled_subgraph_impl import (
CSCFormatBase,
FusedSampledSubgraphImpl,
SampledSubgraphImpl,
)
__all__ = [
......@@ -342,9 +346,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
), "Nodes cannot have duplicate values."
_in_subgraph = self._c_csc_graph.in_subgraph(nodes)
return self._convert_to_sampled_subgraph(_in_subgraph)
return self._convert_to_fused_sampled_subgraph(_in_subgraph)
def _convert_to_sampled_subgraph(
def _convert_to_fused_sampled_subgraph(
self,
C_sampled_subgraph: torch.ScriptObject,
):
......@@ -400,13 +404,109 @@ class FusedCSCSamplingGraph(SamplingGraph):
homogeneous_nodes.append(ids + self.node_type_offset[ntype_id])
return torch.cat(homogeneous_nodes)
def _convert_to_sampled_subgraph(
    self,
    C_sampled_subgraph: torch.ScriptObject,
) -> SampledSubgraphImpl:
    """An internal function used to convert a fused homogeneous sampled
    subgraph to general struct 'SampledSubgraphImpl'.

    Parameters
    ----------
    C_sampled_subgraph : torch.ScriptObject
        Sampled subgraph exposing ``indptr``, ``indices``,
        ``type_per_edge``, ``original_column_node_ids`` and
        ``original_edge_ids``.

    Returns
    -------
    SampledSubgraphImpl
        ``node_pairs`` is a single ``CSCFormatBase`` when ``type_per_edge``
        is None, otherwise a dict mapping every edge type in the graph
        metadata to a ``CSCFormatBase``. ``original_edge_ids`` follows the
        same layout when the graph stores ``ORIGINAL_EDGE_ID``.
    """
    indptr = C_sampled_subgraph.indptr
    indices = C_sampled_subgraph.indices
    type_per_edge = C_sampled_subgraph.type_per_edge
    column = C_sampled_subgraph.original_column_node_ids
    original_edge_ids = C_sampled_subgraph.original_edge_ids

    has_original_eids = (
        self.edge_attributes is not None
        and ORIGINAL_EDGE_ID in self.edge_attributes
    )
    if has_original_eids:
        # Map edge ids of this graph back to ids of the original graph.
        original_edge_ids = self.edge_attributes[ORIGINAL_EDGE_ID][
            original_edge_ids
        ]

    if type_per_edge is None:
        # The sampled graph is already a homogeneous graph.
        return SampledSubgraphImpl(
            node_pairs=CSCFormatBase(indptr=indptr, indices=indices),
            original_edge_ids=original_edge_ids,
        )

    # The sampled graph is a fused homogenized graph, which needs to be
    # converted to heterogeneous graphs.
    # Pre-calculate the number of edges of each etype in a single pass
    # instead of calling `.item()` once per edge.
    etype_ids, etype_counts = torch.unique(type_per_edge, return_counts=True)
    num = dict(zip(etype_ids.tolist(), etype_counts.tolist()))

    # Preallocate per-etype outputs.
    subgraph_indice_position = {}
    subgraph_indice = {}
    subgraph_indptr = {}
    # dst node type id -> [(etype, etype_id, offset of the src node type)]
    node_edge_type = {}
    original_hetero_edge_ids = {}
    for etype, etype_id in self.metadata.edge_type_to_id.items():
        num_etype_edges = num.get(etype_id, 0)
        subgraph_indice[etype] = torch.empty(
            (num_etype_edges,), dtype=indices.dtype
        )
        if has_original_eids:
            original_hetero_edge_ids[etype] = torch.empty(
                (num_etype_edges,), dtype=original_edge_ids.dtype
            )
        subgraph_indptr[etype] = [0]
        subgraph_indice_position[etype] = 0
        # Group edge types by the node type of their destination so that
        # each seed only visits the edge types that can point at it. The
        # source-type offset is loop-invariant, so resolve it here once.
        src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
        dst_ntype_id = self.metadata.node_type_to_id[dst_ntype]
        src_offset = self.node_type_offset[
            self.metadata.node_type_to_id[src_ntype]
        ]
        node_edge_type.setdefault(dst_ntype_id, []).append(
            (etype, etype_id, src_offset)
        )

    # Construct subgraphs.
    # NOTE(review): the searchsorted below assumes that, within one seed's
    # slice, edges are ordered by ascending edge type id and that edge types
    # are visited in ascending id order -- confirm against the sampler.
    for i, seed in enumerate(column):
        start = indptr[i].item()
        stop = indptr[i + 1].item()
        node_type = (
            torch.searchsorted(
                self.node_type_offset, seed, right=True
            ).item()
            - 1
        )
        for etype, etype_id, src_offset in node_edge_type.get(node_type, ()):
            num_edges = torch.searchsorted(
                type_per_edge[start:stop], etype_id, right=True
            ).item()
            end = start + num_edges
            subgraph_indptr[etype].append(
                subgraph_indptr[etype][-1] + num_edges
            )
            offset = subgraph_indice_position[etype]
            subgraph_indice_position[etype] += num_edges
            # Convert homogeneous source ids to type-local ids.
            subgraph_indice[etype][offset : offset + num_edges] = (
                indices[start:end] - src_offset
            )
            if has_original_eids:
                original_hetero_edge_ids[etype][
                    offset : offset + num_edges
                ] = original_edge_ids[start:end]
            # The next edge type starts where this one ended.
            start = end

    if has_original_eids:
        original_edge_ids = original_hetero_edge_ids
    node_pairs = {
        etype: CSCFormatBase(
            indptr=torch.tensor(subgraph_indptr[etype]),
            indices=subgraph_indice[etype],
        )
        for etype in self.metadata.edge_type_to_id
    }
    return SampledSubgraphImpl(
        node_pairs=node_pairs,
        original_edge_ids=original_edge_ids,
    )
def sample_neighbors(
self,
nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
fanouts: torch.Tensor,
replace: bool = False,
probs_name: Optional[str] = None,
) -> FusedSampledSubgraphImpl:
deduplicate=True,
) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
"""Sample neighboring edges of the given nodes and return the induced
subgraph.
......@@ -476,8 +576,10 @@ class FusedCSCSamplingGraph(SamplingGraph):
C_sampled_subgraph = self._sample_neighbors(
nodes, fanouts, replace, probs_name
)
return self._convert_to_sampled_subgraph(C_sampled_subgraph)
if deduplicate is True:
return self._convert_to_fused_sampled_subgraph(C_sampled_subgraph)
else:
return self._convert_to_sampled_subgraph(C_sampled_subgraph)
def _check_sampler_arguments(self, nodes, fanouts, probs_name):
assert nodes.dim() == 1, "Nodes should be 1-D tensor."
......@@ -584,7 +686,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
fanouts: torch.Tensor,
replace: bool = False,
probs_name: Optional[str] = None,
) -> FusedSampledSubgraphImpl:
deduplicate=True,
) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
"""Sample neighboring edges of the given nodes and return the induced
subgraph via layer-neighbor sampling from the NeurIPS 2023 paper
`Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs
......@@ -667,7 +770,10 @@ class FusedCSCSamplingGraph(SamplingGraph):
probs_name,
)
return self._convert_to_sampled_subgraph(C_sampled_subgraph)
if deduplicate:
return self._convert_to_fused_sampled_subgraph(C_sampled_subgraph)
else:
return self._convert_to_sampled_subgraph(C_sampled_subgraph)
def sample_negative_edges_uniform(
self, edge_type, node_pairs, negative_ratio
......
......@@ -115,6 +115,7 @@ class NeighborSampler(SubgraphSampler):
self.fanouts[hop],
self.replace,
self.prob_name,
self.deduplicate,
)
if self.deduplicate:
(
......
......@@ -5,10 +5,10 @@ from typing import Dict, Tuple, Union
import torch
from ..base import etype_str_to_tuple
from ..base import CSCFormatBase, etype_str_to_tuple
from ..sampled_subgraph import SampledSubgraph
__all__ = ["FusedSampledSubgraphImpl"]
__all__ = ["FusedSampledSubgraphImpl", "SampledSubgraphImpl"]
@dataclass
......@@ -67,3 +67,65 @@ class FusedSampledSubgraphImpl(SampledSubgraph):
assert all(
isinstance(item, torch.Tensor) for item in self.node_pairs
), "Nodes in pairs should be of type torch.Tensor."
@dataclass
class SampledSubgraphImpl(SampledSubgraph):
    r"""Sampled subgraph of CSCSamplingGraph.

    ``node_pairs`` is either a single ``CSCFormatBase`` or a dict mapping an
    edge type string of the form ``str:str:str`` to a ``CSCFormatBase``.
    The structure is validated in ``__post_init__``.

    Examples
    --------
    >>> node_pairs = {"A:relation:B": CSCFormatBase(indptr=torch.tensor([0, 1, 2, 3]),
    ... indices=torch.tensor([0, 1, 2]))}
    >>> original_column_node_ids = {'B': torch.tensor([10, 11, 12])}
    >>> original_row_node_ids = {'A': torch.tensor([13, 14, 15])}
    >>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
    >>> subgraph = gb.SampledSubgraphImpl(
    ... node_pairs=node_pairs,
    ... original_column_node_ids=original_column_node_ids,
    ... original_row_node_ids=original_row_node_ids,
    ... original_edge_ids=original_edge_ids
    ... )
    >>> print(subgraph.node_pairs)
    {'A:relation:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]), indices=tensor([0, 1, 2]))}
    >>> print(subgraph.original_column_node_ids)
    {'B': tensor([10, 11, 12])}
    >>> print(subgraph.original_row_node_ids)
    {'A': tensor([13, 14, 15])}
    >>> print(subgraph.original_edge_ids)
    {'A:relation:B': tensor([19, 20, 21])}
    """

    node_pairs: Union[
        CSCFormatBase,
        Dict[str, CSCFormatBase],
    ] = None
    original_column_node_ids: Union[
        Dict[str, torch.Tensor], torch.Tensor
    ] = None
    original_row_node_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
    original_edge_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None

    def __post_init__(self):
        """Validate ``node_pairs`` right after construction.

        Raises
        ------
        AssertionError
            If an edge type key is malformed, or a CSC entry is missing or
            does not hold tensors.
        """
        if isinstance(self.node_pairs, dict):
            for etype, pair in self.node_pairs.items():
                assert (
                    isinstance(etype, str)
                    and len(etype_str_to_tuple(etype)) == 3
                ), "Edge type should be a string in format of str:str:str."
                self._check_csc_pair(pair)
        else:
            self._check_csc_pair(self.node_pairs)

    @staticmethod
    def _check_csc_pair(pair):
        """Assert that ``pair`` carries tensor ``indptr`` and ``indices``."""
        # Fail with a clear message instead of an AttributeError on None.
        assert pair is not None, "Node pairs should not be None."
        assert (
            pair.indptr is not None and pair.indices is not None
        ), "Node pair should be have indptr and indice."
        assert isinstance(pair.indptr, torch.Tensor) and isinstance(
            pair.indices, torch.Tensor
        ), "Nodes in pairs should be of type torch.Tensor."
......@@ -1680,3 +1680,660 @@ def test_csc_sampling_graph_to_device():
assert graph.csc_indptr.device.type == "cuda"
for key in graph.edge_attributes:
assert graph.edge_attributes[key].device.type == "cuda"
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Graph is CPU only at present.",
)
def test_sample_neighbors_homo_csc_format():
    """Sample a homogeneous graph without deduplication and check CSC sizes.

    Dense adjacency of the test graph (row = source, column = destination):
        1 0 1 0 1
        1 0 1 1 0
        0 1 0 1 0
        0 1 0 0 1
        1 0 0 0 1
    """
    # Graph definition in CSC form.
    num_nodes = 5
    num_edges = 12
    csc_indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
    csc_indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
    assert csc_indptr[-1] == num_edges
    assert csc_indptr[-1] == len(csc_indices)

    # Build the FusedCSCSamplingGraph.
    graph = gb.from_fused_csc(csc_indptr, csc_indices)

    # Sample two neighbors for each of three seeds.
    seeds = torch.LongTensor([1, 3, 4])
    subgraph = graph.sample_neighbors(
        seeds, fanouts=torch.LongTensor([2]), deduplicate=False
    )

    # Three seeds give four indptr entries; 3 seeds * fanout 2 gives 6 edges.
    indptr_len = subgraph.node_pairs.indptr.size(0)
    edge_count = subgraph.node_pairs.indices.size(0)
    assert indptr_len == 4
    assert edge_count == 6
    assert subgraph.original_column_node_ids is None
    assert subgraph.original_row_node_ids is None
    assert subgraph.original_edge_ids is None
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_hetero_csc_format(labor):
    """Sample a heterogeneous graph and compare per-etype CSC outputs.

    Original graph in COO:
    "n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
    "n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
        0 0 1 0 1
        0 0 1 1 1
        1 1 0 0 0
        0 1 0 0 0
        1 0 0 0 0
    """
    # Graph definition.
    node_type_ids = {"n1": 0, "n2": 1}
    edge_type_ids = {"n1:e1:n2": 0, "n2:e2:n1": 1}
    metadata = gb.GraphMetadata(node_type_ids, edge_type_ids)
    num_nodes = 5
    num_edges = 9
    csc_indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
    csc_indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
    edge_types = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
    ntype_offset = torch.LongTensor([0, 2, 5])
    assert csc_indptr[-1] == num_edges
    assert csc_indptr[-1] == len(csc_indices)

    # Build the FusedCSCSamplingGraph.
    graph = gb.from_fused_csc(
        csc_indptr,
        csc_indices,
        node_type_offset=ntype_offset,
        type_per_edge=edge_types,
        metadata=metadata,
    )

    def sample_and_check(seeds, expected_node_pairs):
        """Sample all neighbors of `seeds` and verify the CSC result."""
        fanouts = torch.tensor([-1, -1])
        sampler = (
            graph.sample_layer_neighbors if labor else graph.sample_neighbors
        )
        subgraph = sampler(seeds, fanouts, deduplicate=False)
        assert len(subgraph.node_pairs) == 2
        for etype, expected in expected_node_pairs.items():
            actual = subgraph.node_pairs[etype]
            assert torch.equal(actual.indptr, expected.indptr)
            assert torch.equal(actual.indices, expected.indices)
        assert subgraph.original_column_node_ids is None
        assert subgraph.original_row_node_ids is None
        assert subgraph.original_edge_ids is None

    # Sample on both node types.
    sample_and_check(
        {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])},
        {
            "n1:e1:n2": gb.CSCFormatBase(
                indptr=torch.LongTensor([0, 2]),
                indices=torch.LongTensor([0, 1]),
            ),
            "n2:e2:n1": gb.CSCFormatBase(
                indptr=torch.LongTensor([0, 2]),
                indices=torch.LongTensor([0, 2]),
            ),
        },
    )

    # Sample on a single node type: the other edge type stays empty.
    sample_and_check(
        {"n1": torch.LongTensor([0])},
        {
            "n1:e1:n2": gb.CSCFormatBase(
                indptr=torch.LongTensor([0]),
                indices=torch.LongTensor([]),
            ),
            "n2:e2:n1": gb.CSCFormatBase(
                indptr=torch.LongTensor([0, 2]),
                indices=torch.LongTensor([0, 2]),
            ),
        },
    )
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize(
    "fanouts, expected_sampled_num1, expected_sampled_num2",
    [
        ([0], 0, 0),
        ([1], 1, 1),
        ([2], 2, 2),
        ([4], 2, 2),
        ([-1], 2, 2),
        ([0, 0], 0, 0),
        ([1, 0], 1, 0),
        ([0, 1], 0, 1),
        ([1, 1], 1, 1),
        ([2, 1], 2, 1),
        ([-1, -1], 2, 2),
    ],
)
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_fanouts_csc_format(
    fanouts, expected_sampled_num1, expected_sampled_num2, labor
):
    """Check per-etype sampled edge counts for a range of fanouts.

    Original graph in COO:
    "n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
    "n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
        0 0 1 0 1
        0 0 1 1 1
        1 1 0 0 0
        0 1 0 0 0
        1 0 0 0 0
    """
    # Graph definition.
    node_type_ids = {"n1": 0, "n2": 1}
    edge_type_ids = {"n1:e1:n2": 0, "n2:e2:n1": 1}
    metadata = gb.GraphMetadata(node_type_ids, edge_type_ids)
    num_nodes = 5
    num_edges = 9
    csc_indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
    csc_indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
    edge_types = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
    ntype_offset = torch.LongTensor([0, 2, 5])
    assert csc_indptr[-1] == num_edges
    assert csc_indptr[-1] == len(csc_indices)

    # Build the FusedCSCSamplingGraph.
    graph = gb.from_fused_csc(
        csc_indptr,
        csc_indices,
        node_type_offset=ntype_offset,
        type_per_edge=edge_types,
        metadata=metadata,
    )

    seeds = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
    fanouts = torch.LongTensor(fanouts)
    sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
    subgraph = sampler(seeds, fanouts, deduplicate=False)

    # An expected count of 0 is not enforced on the number of indices.
    expectations = (
        ("n1:e1:n2", expected_sampled_num1),
        ("n2:e2:n1", expected_sampled_num2),
    )
    for etype, expected_num in expectations:
        assert (
            expected_num == 0
            or subgraph.node_pairs[etype].indices.numel() == expected_num
        )
        assert subgraph.node_pairs[etype].indptr.size(0) == 2
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize(
    "replace, expected_sampled_num1, expected_sampled_num2",
    [(False, 2, 2), (True, 4, 4)],
)
def test_sample_neighbors_replace_csc_format(
    replace, expected_sampled_num1, expected_sampled_num2
):
    """Check sampled edge counts with and without replacement.

    Original graph in COO:
    "n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
    "n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
        0 0 1 0 1
        0 0 1 1 1
        1 1 0 0 0
        0 1 0 0 0
        1 0 0 0 0
    """
    # Graph definition.
    node_type_ids = {"n1": 0, "n2": 1}
    edge_type_ids = {"n1:e1:n2": 0, "n2:e2:n1": 1}
    metadata = gb.GraphMetadata(node_type_ids, edge_type_ids)
    num_nodes = 5
    num_edges = 9
    csc_indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
    csc_indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
    edge_types = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
    ntype_offset = torch.LongTensor([0, 2, 5])
    assert csc_indptr[-1] == num_edges
    assert csc_indptr[-1] == len(csc_indices)

    # Build the FusedCSCSamplingGraph.
    graph = gb.from_fused_csc(
        csc_indptr,
        csc_indices,
        node_type_offset=ntype_offset,
        type_per_edge=edge_types,
        metadata=metadata,
    )

    seeds = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
    subgraph = graph.sample_neighbors(
        seeds, torch.LongTensor([4]), replace=replace, deduplicate=False
    )

    # One seed per destination type, so every indptr has two entries.
    expectations = (
        ("n1:e1:n2", expected_sampled_num1),
        ("n2:e2:n1", expected_sampled_num2),
    )
    for etype, expected_num in expectations:
        assert subgraph.node_pairs[etype].indices.numel() == expected_num
        assert subgraph.node_pairs[etype].indptr.size(0) == 2
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_return_eids_homo_csc_format(labor):
    """Check that original edge ids are returned for a homogeneous graph.

    Dense adjacency of the test graph (row = source, column = destination):
        1 0 1 0 1
        1 0 1 1 0
        0 1 0 1 0
        0 1 0 0 1
        1 0 0 0 1
    """
    # Graph definition in CSC form.
    num_nodes = 5
    num_edges = 12
    csc_indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
    csc_indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
    assert csc_indptr[-1] == num_edges
    assert csc_indptr[-1] == len(csc_indices)

    # Random mapping from CSC edge ids to ids of the original graph.
    eid_mapping = torch.randperm(num_edges)
    edge_attributes = {gb.ORIGINAL_EDGE_ID: eid_mapping}

    # Build the FusedCSCSamplingGraph.
    graph = gb.from_fused_csc(
        csc_indptr, csc_indices, edge_attributes=edge_attributes
    )

    # Sample every in-edge of the seeds.
    seeds = torch.LongTensor([1, 3, 4])
    sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
    subgraph = sampler(
        seeds, fanouts=torch.LongTensor([-1]), deduplicate=False
    )

    # CSC positions of the in-edges of seeds 1, 3 and 4.
    picked_positions = torch.tensor([3, 4, 7, 8, 9, 10, 11])
    expected_reverse_edge_ids = edge_attributes[gb.ORIGINAL_EDGE_ID][
        picked_positions
    ]
    assert torch.equal(expected_reverse_edge_ids, subgraph.original_edge_ids)
    assert subgraph.original_column_node_ids is None
    assert subgraph.original_row_node_ids is None
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_return_eids_hetero_csc_format(labor):
    """Check per-etype original edge ids for a heterogeneous graph.

    Original graph in COO:
    "n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
    "n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
        0 0 1 0 1
        0 0 1 1 1
        1 1 0 0 0
        0 1 0 0 0
        1 0 0 0 0
    """
    # Graph definition.
    ntypes = {"n1": 0, "n2": 1}
    etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
    metadata = gb.GraphMetadata(ntypes, etypes)
    num_nodes = 5
    num_edges = 9
    csc_indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
    csc_indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
    edge_types = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
    ntype_offset = torch.LongTensor([0, 2, 5])
    # Independent random permutations for the first 4 and last 5 edges.
    eid_mapping = torch.cat([torch.randperm(4), torch.randperm(5)])
    edge_attributes = {gb.ORIGINAL_EDGE_ID: eid_mapping}
    assert csc_indptr[-1] == num_edges
    assert csc_indptr[-1] == len(csc_indices)

    # Build the FusedCSCSamplingGraph.
    graph = gb.from_fused_csc(
        csc_indptr,
        csc_indices,
        node_type_offset=ntype_offset,
        type_per_edge=edge_types,
        edge_attributes=edge_attributes,
        metadata=metadata,
    )

    # Sample on both node types.
    seeds = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
    fanouts = torch.tensor([-1, -1])
    sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
    subgraph = sampler(seeds, fanouts, deduplicate=False)

    # CSC edge positions expected for each edge type.
    expected_positions = {"n2:e2:n1": [0, 1], "n1:e1:n2": [4, 5]}
    expected_reverse_edge_ids = {
        etype: edge_attributes[gb.ORIGINAL_EDGE_ID][torch.tensor(positions)]
        for etype, positions in expected_positions.items()
    }
    assert subgraph.original_column_node_ids is None
    assert subgraph.original_row_node_ids is None
    for etype in etypes:
        assert torch.equal(
            subgraph.original_edge_ids[etype], expected_reverse_edge_ids[etype]
        )
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("replace", [True, False])
@pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize("probs_name", ["weight", "mask"])
def test_sample_neighbors_probs_csc_format(replace, labor, probs_name):
    """Check sampled edge counts when sampling with weights or a mask.

    Dense adjacency of the test graph (row = source, column = destination):
        1 0 1 0 1
        1 0 1 1 0
        0 1 0 1 0
        0 1 0 0 1
        1 0 0 0 1
    """
    # Graph definition in CSC form.
    num_nodes = 5
    num_edges = 12
    csc_indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
    csc_indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
    assert csc_indptr[-1] == num_edges
    assert csc_indptr[-1] == len(csc_indices)

    weight = torch.FloatTensor(
        [2.5, 0, 8.4, 0, 0.4, 1.2, 2.5, 0, 8.4, 0.5, 0.4, 1.2]
    )
    mask = torch.BoolTensor([1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1])
    edge_attributes = {"weight": weight, "mask": mask}

    # Build the FusedCSCSamplingGraph.
    graph = gb.from_fused_csc(
        csc_indptr, csc_indices, edge_attributes=edge_attributes
    )

    # Sample up to two neighbors per seed.
    seeds = torch.LongTensor([1, 3, 4])
    sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
    subgraph = sampler(
        seeds,
        fanouts=torch.tensor([2]),
        replace=replace,
        probs_name=probs_name,
        deduplicate=False,
    )

    # Three seeds give four indptr entries.
    sampled_num = subgraph.node_pairs.indices.size(0)
    assert subgraph.node_pairs.indptr.size(0) == 4
    expected_num = 6 if replace else 4
    assert sampled_num == expected_num
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("replace", [True, False])
@pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize(
    "probs_or_mask",
    [
        torch.zeros(12, dtype=torch.float32),
        torch.zeros(12, dtype=torch.bool),
    ],
)
def test_sample_neighbors_zero_probs_csc_format(replace, labor, probs_or_mask):
    """Check that all-zero probabilities or masks yield no sampled edges."""
    # Graph definition in CSC form.
    num_nodes = 5
    num_edges = 12
    csc_indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
    csc_indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
    assert csc_indptr[-1] == num_edges
    assert csc_indptr[-1] == len(csc_indices)

    # Build the FusedCSCSamplingGraph.
    graph = gb.from_fused_csc(
        csc_indptr,
        csc_indices,
        edge_attributes={"probs_or_mask": probs_or_mask},
    )

    # Sample from three seeds.
    seeds = torch.LongTensor([1, 3, 4])
    sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
    subgraph = sampler(
        seeds,
        fanouts=torch.tensor([5]),
        replace=replace,
        probs_name="probs_or_mask",
        deduplicate=False,
    )

    # The indptr still covers the three seeds, but no edge is selected.
    csc = subgraph.node_pairs
    sampled_num = csc.indices.size(0)
    assert csc.indptr.size(0) == 4
    assert sampled_num == 0
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("replace", [False, True])
@pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize(
    "fanouts, probs_name",
    [
        ([2], "mask"),
        ([3], "mask"),
        ([4], "mask"),
        ([-1], "mask"),
        ([7], "mask"),
        ([3], "all"),
        ([-1], "all"),
        ([7], "all"),
        ([3], "zero"),
        ([-1], "zero"),
        ([3], "none"),
        ([-1], "none"),
    ],
)
def test_sample_neighbors_homo_pick_number_csc_format(
    fanouts, replace, labor, probs_name
):
    """Check the number of picked edges on a homogeneous graph.

    Original graph in COO:
        1 1 1 1 1 1
        0 0 0 0 0 0
        0 0 0 0 0 0
        0 0 0 0 0 0
        0 0 0 0 0 0
        0 0 0 0 0 0
    """
    # Graph definition: all six edges belong to the first indptr range.
    num_nodes = 6
    num_edges = 6
    csc_indptr = torch.LongTensor([0, 6, 6, 6, 6, 6, 6])
    csc_indices = torch.LongTensor([0, 1, 2, 3, 4, 5])
    assert csc_indptr[-1] == num_edges
    assert csc_indptr[-1] == len(csc_indices)

    edge_attributes = {
        "mask": torch.BoolTensor([1, 0, 0, 1, 0, 1]),
        "all": torch.BoolTensor([1, 1, 1, 1, 1, 1]),
        "zero": torch.BoolTensor([0, 0, 0, 0, 0, 0]),
    }

    # Build the FusedCSCSamplingGraph.
    graph = gb.from_fused_csc(
        csc_indptr, csc_indices, edge_attributes=edge_attributes
    )

    # Sample; this call itself must not raise.
    seeds = torch.LongTensor([0, 1])
    sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
    subgraph = sampler(
        seeds,
        fanouts=torch.LongTensor(fanouts),
        replace=replace,
        probs_name=None if probs_name == "none" else probs_name,
        deduplicate=False,
    )
    sampled_num = subgraph.node_pairs.indices.size(0)
    assert subgraph.node_pairs.indptr.size(0) == 3

    def expected_picks(fanout, num_valid):
        """Expected pick count given the number of selectable edges."""
        if fanout == -1:
            return num_valid
        if replace:
            return fanout
        return min(fanout, num_valid)

    # Verify the number of sampled edges.
    if probs_name == "zero":
        assert sampled_num == 0
    elif probs_name == "mask":
        # The mask leaves three selectable edges.
        assert sampled_num == expected_picks(fanouts[0], 3)
    else:
        # "all" and "none" leave all six edges selectable.
        assert sampled_num == expected_picks(fanouts[0], 6)
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("replace", [False, True])
@pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize(
    "fanouts, probs_name",
    [
        ([-1, -1, -1], "mask"),
        ([1, 1, 1], "mask"),
        ([2, 2, 2], "mask"),
        ([3, 3, 3], "mask"),
        ([4, 4, 4], "mask"),
        ([-1, 1, 3], "none"),
        ([2, -1, 4], "none"),
    ],
)
def test_sample_neighbors_hetero_pick_number_csc_format(
    fanouts, replace, labor, probs_name
):
    """Check the per-etype number of picked edges on a heterogeneous graph.

    Node 0 (type N0) has nine in-edges, three for each of the three edge
    types. With the "mask" attribute, etype 0 has two selectable edges,
    etype 1 has three and etype 2 has none.
    """
    # Initialize data.
    total_num_nodes = 10
    total_num_edges = 9
    ntypes = {"N0": 0, "N1": 1, "N2": 2, "N3": 3}
    etypes = {
        "N1:R0:N0": 0,
        "N2:R1:N0": 1,
        "N3:R2:N0": 2,
    }
    metadata = gb.GraphMetadata(ntypes, etypes)
    indptr = torch.LongTensor([0, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9])
    indices = torch.LongTensor([1, 2, 3, 4, 5, 6, 7, 8, 9])
    node_type_offset = torch.LongTensor([0, 1, 4, 7, 10])
    type_per_edge = torch.LongTensor([0, 0, 0, 1, 1, 1, 2, 2, 2])
    assert indptr[-1] == total_num_edges
    assert indptr[-1] == len(indices)
    assert node_type_offset[-1] == total_num_nodes
    assert all(type_per_edge < len(etypes))

    edge_attributes = {
        "mask": torch.BoolTensor([1, 1, 0, 1, 1, 1, 0, 0, 0]),
        "all": torch.BoolTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]),
        "zero": torch.BoolTensor([0, 0, 0, 0, 0, 0, 0, 0, 0]),
    }

    # Construct FusedCSCSamplingGraph.
    graph = gb.from_fused_csc(
        indptr,
        indices,
        edge_attributes=edge_attributes,
        node_type_offset=node_type_offset,
        type_per_edge=type_per_edge,
        metadata=metadata,
    )

    # Generate subgraph via sample neighbors.
    nodes = torch.LongTensor([0, 1])
    sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors

    # Make sure no exception will be thrown.
    subgraph = sampler(
        nodes,
        fanouts=torch.LongTensor(fanouts),
        replace=replace,
        probs_name=probs_name if probs_name != "none" else None,
        deduplicate=False,
    )

    def expected_picks(fanout, num_valid):
        """Expected pick count given the number of selectable edges."""
        if fanout == -1:
            return num_valid
        if replace:
            return fanout
        return min(fanout, num_valid)

    if probs_name == "none":
        # Without probabilities every etype has three selectable edges and
        # its own fanout.
        for etype, pairs in subgraph.node_pairs.items():
            assert pairs.indptr.size(0) == 2
            sampled_num = pairs.indices.size(0)
            assert sampled_num == expected_picks(fanouts[etypes[etype]], 3)
    else:
        fanout = fanouts[0]  # Here fanout is the same for all etypes.
        # Selectable edges per etype id under the mask; etype 2 has none.
        valid_per_etype_id = {0: 2, 1: 3}
        for etype, pairs in subgraph.node_pairs.items():
            assert pairs.indptr.size(0) == 2
            sampled_num = pairs.indices.size(0)
            etype_id = etypes[etype]
            if etype_id in valid_per_etype_id:
                assert sampled_num == expected_picks(
                    fanout, valid_per_etype_id[etype_id]
                )
            else:
                # Etype 2: 0 valid neighbors, so nothing can be picked even
                # with replacement.
                assert sampled_num == 0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment