Unverified commit fa1ae3b7 authored by yxy235, committed by GitHub

[GraphBolt] Remove coo from `FusedCSCSamplingGraph`. (#6849)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent f81e514d
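
After this change, `in_subgraph`, `sample_neighbors` and `sample_layer_neighbors` no longer take the `output_cscformat` flag and always return a `SampledSubgraphImpl` whose `sampled_csc` field is CSC-shaped (`CSCFormatBase`) rather than COO node pairs. A minimal usage sketch built from the homogeneous test fixture in this diff (illustrative only, not part of the commit):

    import torch
    import dgl.graphbolt as gb

    # Same toy CSC graph as the homogeneous tests in this diff.
    indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
    indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
    graph = gb.fused_csc_sampling_graph(indptr, indices)

    nodes = torch.LongTensor([1, 3, 4])
    # No `output_cscformat` argument any more; the result is always CSC-shaped.
    subgraph = graph.sample_neighbors(nodes, fanouts=torch.LongTensor([2]))
    print(subgraph.sampled_csc.indptr, subgraph.sampled_csc.indices)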
@@ -12,11 +12,7 @@ from ...convert import to_homogeneous
from ...heterograph import DGLGraph
from ..base import etype_str_to_tuple, etype_tuple_to_str, ORIGINAL_EDGE_ID
from ..sampling_graph import SamplingGraph
from .sampled_subgraph_impl import (
CSCFormatBase,
FusedSampledSubgraphImpl,
SampledSubgraphImpl,
)
from .sampled_subgraph_impl import CSCFormatBase, SampledSubgraphImpl

__all__ = [
@@ -324,9 +320,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
def in_subgraph(
self,
nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
# TODO: clean up once the migration is done.
output_cscformat=True,
) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
) -> SampledSubgraphImpl:
"""Return the subgraph induced on the inbound edges of the given nodes.

An in subgraph is equivalent to creating a new graph using the incoming
@@ -344,7 +338,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
Returns
-------
FusedSampledSubgraphImpl
SampledSubgraphImpl
The in subgraph.

Examples
@@ -389,60 +383,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
), "Nodes cannot have duplicate values."

_in_subgraph = self._c_csc_graph.in_subgraph(nodes)
if not output_cscformat:
return self._convert_to_fused_sampled_subgraph(_in_subgraph)
else:
return self._convert_to_sampled_subgraph(_in_subgraph)
def _convert_to_fused_sampled_subgraph(
self,
C_sampled_subgraph: torch.ScriptObject,
):
"""An internal function used to convert a fused homogeneous sampled
subgraph to general struct 'FusedSampledSubgraphImpl'."""
column_num = (
C_sampled_subgraph.indptr[1:] - C_sampled_subgraph.indptr[:-1]
)
column = C_sampled_subgraph.original_column_node_ids.repeat_interleave(
column_num
)
row = C_sampled_subgraph.indices
type_per_edge = C_sampled_subgraph.type_per_edge
original_edge_ids = C_sampled_subgraph.original_edge_ids
has_original_eids = (
self.edge_attributes is not None
and ORIGINAL_EDGE_ID in self.edge_attributes
)
if has_original_eids:
original_edge_ids = self.edge_attributes[ORIGINAL_EDGE_ID][
original_edge_ids
]
if type_per_edge is None:
# The sampled graph is already a homogeneous graph.
node_pairs = (row, column)
else:
# The sampled graph is a fused homogenized graph, which need to be
# converted to heterogeneous graphs.
node_pairs = defaultdict(list)
original_hetero_edge_ids = {}
for etype, etype_id in self.edge_type_to_id.items():
src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
src_ntype_id = self.node_type_to_id[src_ntype]
dst_ntype_id = self.node_type_to_id[dst_ntype]
mask = type_per_edge == etype_id
hetero_row = row[mask] - self.node_type_offset[src_ntype_id]
hetero_column = (
column[mask] - self.node_type_offset[dst_ntype_id]
)
node_pairs[etype] = (hetero_row, hetero_column)
if has_original_eids:
original_hetero_edge_ids[etype] = original_edge_ids[mask]
if has_original_eids:
original_edge_ids = original_hetero_edge_ids
return FusedSampledSubgraphImpl(
sampled_csc=node_pairs, original_edge_ids=original_edge_ids
)
def _convert_to_homogeneous_nodes(self, nodes, timestamps=None):
homogeneous_nodes = []
homogeneous_timestamps = []
@@ -560,9 +502,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
fanouts: torch.Tensor,
replace: bool = False,
probs_name: Optional[str] = None,
# TODO: clean up once the migration is done.
output_cscformat=True,
) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
) -> SampledSubgraphImpl:
"""Sample neighboring edges of the given nodes and return the induced
subgraph.
@@ -604,7 +544,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
Returns
-------
Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]
SampledSubgraphImpl
The sampled subgraph.

Examples
@@ -638,9 +578,6 @@ class FusedCSCSamplingGraph(SamplingGraph):
C_sampled_subgraph = self._sample_neighbors(
nodes, fanouts, replace, probs_name
)
if not output_cscformat:
return self._convert_to_fused_sampled_subgraph(C_sampled_subgraph)
else:
return self._convert_to_sampled_subgraph(C_sampled_subgraph)

def _check_sampler_arguments(self, nodes, fanouts, probs_name):
@@ -753,9 +690,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
fanouts: torch.Tensor,
replace: bool = False,
probs_name: Optional[str] = None,
# TODO: clean up once the migration is done.
output_cscformat=True,
) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
) -> SampledSubgraphImpl:
"""Sample neighboring edges of the given nodes and return the induced
subgraph via layer-neighbor sampling from the NeurIPS 2023 paper
`Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs
@@ -799,7 +734,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
Returns
-------
FusedSampledSubgraphImpl
SampledSubgraphImpl
The sampled subgraph.

Examples
@@ -843,10 +778,6 @@ class FusedCSCSamplingGraph(SamplingGraph):
has_original_eids,
probs_name,
)
if not output_cscformat:
return self._convert_to_fused_sampled_subgraph(C_sampled_subgraph)
else:
return self._convert_to_sampled_subgraph(C_sampled_subgraph)

def temporal_sample_neighbors(
@@ -906,7 +837,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
Returns
-------
FusedSampledSubgraphImpl
SampledSubgraphImpl
The sampled subgraph.
"""
if isinstance(nodes, dict):
......
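
Callers that still depended on the removed COO output (`output_cscformat=False`) can expand the CSC result themselves. A minimal sketch for the homogeneous case, mirroring the `repeat_interleave` logic of the deleted `_convert_to_fused_sampled_subgraph` helper; the standalone `csc_to_coo` function and its arguments are illustrative assumptions, not part of this commit:

    import torch

    def csc_to_coo(indptr, indices, seed_nodes):
        # Number of sampled in-edges per seed (column) node.
        degree = indptr[1:] - indptr[:-1]
        # Repeat each seed id once per incoming edge to form the column array;
        # `indices` already holds the row (source) node ids.
        column = seed_nodes.repeat_interleave(degree)
        return indices, column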
@@ -683,137 +683,6 @@ def test_multiprocessing():
p.join()

@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
def test_in_subgraph_node_pairs_homogeneous():
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
0 1 0 1 0
0 1 0 0 1
1 0 0 0 1
"""
# Initialize data.
total_num_nodes = 5
total_num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
# Construct FusedCSCSamplingGraph.
graph = gb.fused_csc_sampling_graph(indptr, indices)
# Extract in subgraph.
nodes = torch.LongTensor([4, 1, 3])
in_subgraph = graph.in_subgraph(nodes, output_cscformat=False)
# Verify in subgraph.
assert torch.equal(
in_subgraph.sampled_csc[0], torch.LongTensor([0, 3, 4, 2, 3, 1, 2])
)
assert torch.equal(
in_subgraph.sampled_csc[1], torch.LongTensor([4, 4, 4, 1, 1, 3, 3])
)
assert in_subgraph.original_column_node_ids is None
assert in_subgraph.original_row_node_ids is None
assert torch.equal(
in_subgraph.original_edge_ids, torch.LongTensor([9, 10, 11, 3, 4, 7, 8])
)
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
def test_in_subgraph_node_pairs_heterogeneous():
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
0 1 0 1 0
0 1 0 0 1
1 0 0 0 1
node_type_0: [0, 1]
node_type_1: [2, 3, 4]
edge_type_0: node_type_0 -> node_type_0
edge_type_1: node_type_0 -> node_type_1
edge_type_2: node_type_1 -> node_type_0
edge_type_3: node_type_1 -> node_type_1
"""
# Initialize data.
total_num_nodes = 5
total_num_edges = 12
ntypes = {
"N0": 0,
"N1": 1,
}
etypes = {
"N0:R0:N0": 0,
"N0:R1:N1": 1,
"N1:R2:N0": 2,
"N1:R3:N1": 3,
}
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
node_type_offset = torch.LongTensor([0, 2, 5])
type_per_edge = torch.LongTensor([0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
assert node_type_offset[-1] == total_num_nodes
assert all(type_per_edge < len(etypes))
# Construct FusedCSCSamplingGraph.
graph = gb.fused_csc_sampling_graph(
indptr,
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
# Extract in subgraph.
nodes = {
"N0": torch.LongTensor([1]),
"N1": torch.LongTensor([2, 1]),
}
in_subgraph = graph.in_subgraph(nodes, output_cscformat=False)
# Verify in subgraph.
assert torch.equal(
in_subgraph.sampled_csc["N0:R0:N0"][0], torch.LongTensor([])
)
assert torch.equal(
in_subgraph.sampled_csc["N0:R0:N0"][1], torch.LongTensor([])
)
assert torch.equal(
in_subgraph.sampled_csc["N0:R1:N1"][0], torch.LongTensor([0, 1])
)
assert torch.equal(
in_subgraph.sampled_csc["N0:R1:N1"][1], torch.LongTensor([2, 1])
)
assert torch.equal(
in_subgraph.sampled_csc["N1:R2:N0"][0], torch.LongTensor([0, 1])
)
assert torch.equal(
in_subgraph.sampled_csc["N1:R2:N0"][1], torch.LongTensor([1, 1])
)
assert torch.equal(
in_subgraph.sampled_csc["N1:R3:N1"][0], torch.LongTensor([1, 2, 0])
)
assert torch.equal(
in_subgraph.sampled_csc["N1:R3:N1"][1], torch.LongTensor([2, 2, 1])
)
assert in_subgraph.original_column_node_ids is None
assert in_subgraph.original_row_node_ids is None
assert torch.equal(
in_subgraph.original_edge_ids, torch.LongTensor([3, 4, 9, 10, 11, 7, 8])
)
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
@@ -839,7 +708,7 @@ def test_in_subgraph_homo():
# Extract in subgraph.
nodes = torch.LongTensor([4, 1, 3])
in_subgraph = graph.in_subgraph(nodes, output_cscformat=True)
in_subgraph = graph.in_subgraph(nodes)

# Verify in subgraph.
assert torch.equal(
@@ -911,7 +780,7 @@ def test_in_subgraph_hetero():
"N0": torch.LongTensor([1]),
"N1": torch.LongTensor([2, 1]),
}
in_subgraph = graph.in_subgraph(nodes, output_cscformat=True)
in_subgraph = graph.in_subgraph(nodes)

# Verify in subgraph.
assert torch.equal(
@@ -945,64 +814,6 @@ def test_in_subgraph_hetero():
)

@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
def test_sample_neighbors_homo(labor, indptr_dtype, indices_dtype):
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
0 1 0 1 0
0 1 0 0 1
1 0 0 0 1
"""
# Initialize data.
total_num_nodes = 5
total_num_edges = 12
indptr = torch.tensor([0, 3, 5, 7, 9, 12], dtype=indptr_dtype)
indices = torch.tensor(
[0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4], dtype=indices_dtype
)
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
assert len(indptr) == total_num_nodes + 1
# Construct FusedCSCSamplingGraph.
graph = gb.fused_csc_sampling_graph(indptr, indices)
# Generate subgraph via sample neighbors.
fanouts = torch.LongTensor([2])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
# 1. Sample with nodes in mismatched dtype with graph's indices.
nodes = torch.tensor(
[1, 3, 4],
dtype=(torch.int64 if indices_dtype == torch.int32 else torch.int32),
)
with pytest.raises(
AssertionError,
match=re.escape(
"Data type of nodes must be consistent with indices.dtype"
),
):
_ = sampler(nodes, fanouts)
# 2. Sample with nodes in matched dtype with graph's indices.
nodes = torch.tensor([1, 3, 4], dtype=indices_dtype)
subgraph = sampler(nodes, fanouts, output_cscformat=False)
# Verify in subgraph.
sampled_num = subgraph.sampled_csc[0].size(0)
assert sampled_num == 6
assert subgraph.original_column_node_ids is None
assert subgraph.original_row_node_ids is None
assert subgraph.original_edge_ids is None
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
@@ -1125,10 +936,14 @@ def test_temporal_sample_neighbors_homo(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
"""Original graph in COO:
"n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]

@@ -1141,6 +956,8 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
# Initialize data.
ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
total_num_edges = 9
indptr = torch.tensor([0, 2, 4, 6, 7, 9], dtype=indptr_dtype)
indices = torch.tensor([2, 4, 2, 3, 0, 1, 1, 0, 1], dtype=indices_dtype)

@@ -1159,561 +976,97 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
edge_type_to_id=etypes,
)

# Sample on both node types.
fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
# 1. Sample with nodes in mismatched dtype with graph's indices.
nodes = {
"n1": torch.tensor(
[0],
dtype=(
torch.int64 if indices_dtype == torch.int32 else torch.int32
),
),
"n2": torch.tensor(
[0],
dtype=(
torch.int64 if indices_dtype == torch.int32 else torch.int32
),
),
}
with pytest.raises(
AssertionError,
match=re.escape(
"Data type of nodes must be consistent with indices.dtype"
),
):
_ = sampler(nodes, fanouts, output_cscformat=False)
# 2. Sample with nodes in matched dtype with graph's indices.
nodes = {
"n1": torch.tensor([0], dtype=indices_dtype),
"n2": torch.tensor([0], dtype=indices_dtype),
}
subgraph = sampler(nodes, fanouts, output_cscformat=False)

# Verify in subgraph.
expected_node_pairs = {
"n1:e1:n2": (
torch.LongTensor([0, 1]),
torch.LongTensor([0, 0]),
),
"n2:e2:n1": (
torch.LongTensor([0, 2]),
torch.LongTensor([0, 0]),
),
}
assert len(subgraph.sampled_csc) == 2
for etype, pairs in expected_node_pairs.items():
assert torch.equal(subgraph.sampled_csc[etype][0], pairs[0])
assert torch.equal(subgraph.sampled_csc[etype][1], pairs[1])
assert subgraph.original_column_node_ids is None
assert subgraph.original_row_node_ids is None
assert subgraph.original_edge_ids is None
# Sample on single node type.
fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
# 1. Sample with nodes in mismatched dtype with graph's indices.
nodes = {
"n1": torch.tensor(
[0],
dtype=(
torch.int64 if indices_dtype == torch.int32 else torch.int32
),
)
} }
with pytest.raises(
AssertionError,
match=re.escape(
"Data type of nodes must be consistent with indices.dtype"
),
):
_ = sampler(nodes, fanouts, output_cscformat=False)
# 2. Sample with nodes in matched dtype with graph's indices.
nodes = {"n1": torch.tensor([0], dtype=indices_dtype)}
subgraph = sampler(nodes, fanouts, output_cscformat=False)
# Verify in subgraph.
expected_node_pairs = {
"n2:e2:n1": (
torch.tensor([0, 2], dtype=indices_dtype),
torch.tensor([0, 0], dtype=indices_dtype),
),
"n1:e1:n2": (
torch.tensor([], dtype=indices_dtype),
torch.tensor([], dtype=indices_dtype),
),
}
assert len(subgraph.sampled_csc) == 2
for etype, pairs in expected_node_pairs.items():
assert torch.equal(subgraph.sampled_csc[etype][0], pairs[0])
assert torch.equal(subgraph.sampled_csc[etype][1], pairs[1])
assert subgraph.original_column_node_ids is None
assert subgraph.original_row_node_ids is None
assert subgraph.original_edge_ids is None


@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("replace", [False, True])
@pytest.mark.parametrize("use_node_timestamp", [False, True])
@pytest.mark.parametrize("use_edge_timestamp", [False, True])
def test_temporal_sample_neighbors_hetero(
indptr_dtype, indices_dtype, replace, use_node_timestamp, use_edge_timestamp
):
"""Original graph in COO:
"n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
0 0 1 0 1
0 0 1 1 1
1 1 0 0 0
0 1 0 0 0
1 0 0 0 0
"""
# Initialize data.
ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
ntypes_to_offset = {"n1": 0, "n2": 2}
total_num_nodes = 5
total_num_edges = 9
indptr = torch.tensor([0, 2, 4, 6, 7, 9], dtype=indptr_dtype)
indices = torch.tensor([2, 4, 2, 3, 0, 1, 1, 0, 1], dtype=indices_dtype)
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
node_type_offset = torch.LongTensor([0, 2, 5])
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
# Construct FusedCSCSamplingGraph.
graph = gb.fused_csc_sampling_graph(
indptr,
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
# Generate subgraph via sample neighbors.
fanouts = torch.LongTensor([-1, -1])
sampler = graph.temporal_sample_neighbors
seeds = {
"n1": torch.tensor([0], dtype=indices_dtype),
"n2": torch.tensor([0], dtype=indices_dtype),
}
per_etype_destination_nodes = {
"n1:e1:n2": torch.tensor([1], dtype=indices_dtype),
"n2:e2:n1": torch.tensor([0], dtype=indices_dtype),
}
seed_timestamp = {
"n1": torch.randint(0, 100, (1,), dtype=torch.int64),
"n2": torch.randint(0, 100, (1,), dtype=torch.int64),
}
if use_node_timestamp:
node_timestamp = torch.randint(
0, 100, (total_num_nodes,), dtype=torch.int64
)
graph.node_attributes = {"timestamp": node_timestamp}
if use_edge_timestamp:
edge_timestamp = torch.randint(
0, 100, (total_num_edges,), dtype=torch.int64
)
graph.edge_attributes = {"timestamp": edge_timestamp}
subgraph, neighbors_timestamp = sampler(
seeds,
seed_timestamp,
fanouts,
replace=replace,
node_timestamp_attr_name="timestamp" if use_node_timestamp else None,
edge_timestamp_attr_name="timestamp" if use_edge_timestamp else None,
)
def _to_homo():
ret_seeds, ret_timestamps = [], []
for ntype, nodes in seeds.items():
ntype_id = ntypes[ntype]
offset = node_type_offset[ntype_id]
ret_seeds.append(nodes + offset)
ret_timestamps.append(seed_timestamp[ntype])
return torch.cat(ret_seeds), torch.cat(ret_timestamps)
homo_seeds, homo_seed_timestamp = _to_homo()
def _get_available_neighbors():
available_neighbors = []
for i, seed in enumerate(homo_seeds):
neighbors = []
start = indptr[seed].item()
end = indptr[seed + 1].item()
for j in range(start, end):
neighbor = indices[j].item()
if (
use_node_timestamp
and (
node_timestamp[neighbor] > homo_seed_timestamp[i]
).item()
):
continue
if (
use_edge_timestamp
and (edge_timestamp[j] > homo_seed_timestamp[i]).item()
):
continue
neighbors.append(neighbor)
available_neighbors.append(neighbors)
return available_neighbors
available_neighbors = _get_available_neighbors()
sampled_count = [0] * homo_seeds.numel()
sampled_neighbors = [[] for _ in range(homo_seeds.numel())]
for etype, csc in subgraph.sampled_csc.items():
stype, _, _ = etype_str_to_tuple(etype)
ntype_offset = ntypes_to_offset[stype]
dest_nodes = per_etype_destination_nodes[etype]
for i in range(dest_nodes.numel()):
l = csc.indptr[i]
r = csc.indptr[i + 1]
seed_offset = dest_nodes[i].item()
sampled_neighbors[seed_offset].extend(
(csc.indices[l:r] + ntype_offset).tolist()
)
sampled_count[seed_offset] += r - l
for i, count in enumerate(sampled_count):
assert count == len(available_neighbors[i])
assert set(sampled_neighbors[i]).issubset(set(available_neighbors[i]))
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize(
"fanouts, expected_sampled_num1, expected_sampled_num2",
[
([0], 0, 0),
([1], 1, 1),
([2], 2, 2),
([4], 2, 2),
([-1], 2, 2),
([0, 0], 0, 0),
([1, 0], 1, 0),
([0, 1], 0, 1),
([1, 1], 1, 1),
([2, 1], 2, 1),
([-1, -1], 2, 2),
],
)
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_fanouts(
fanouts, expected_sampled_num1, expected_sampled_num2, labor
):
"""Original graph in COO:
"n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
0 0 1 0 1
0 0 1 1 1
1 1 0 0 0
0 1 0 0 0
1 0 0 0 0
"""
# Initialize data.
ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
total_num_edges = 9
indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
node_type_offset = torch.LongTensor([0, 2, 5])
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
# Construct FusedCSCSamplingGraph.
graph = gb.fused_csc_sampling_graph(
indptr,
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
fanouts = torch.LongTensor(fanouts)
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts, output_cscformat=False)
# Verify in subgraph.
assert (
expected_sampled_num1 == 0
or subgraph.sampled_csc["n1:e1:n2"][0].numel() == expected_sampled_num1
)
assert (
expected_sampled_num2 == 0
or subgraph.sampled_csc["n2:e2:n1"][0].numel() == expected_sampled_num2
)
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize(
"replace, expected_sampled_num1, expected_sampled_num2",
[(False, 2, 2), (True, 4, 4)],
)
def test_sample_neighbors_replace(
replace, expected_sampled_num1, expected_sampled_num2
):
"""Original graph in COO:
"n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
0 0 1 0 1
0 0 1 1 1
1 1 0 0 0
0 1 0 0 0
1 0 0 0 0
"""
# Initialize data.
ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
total_num_edges = 9
indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
node_type_offset = torch.LongTensor([0, 2, 5])
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
# Construct FusedCSCSamplingGraph.
graph = gb.fused_csc_sampling_graph(
indptr,
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
subgraph = graph.sample_neighbors(
nodes, torch.LongTensor([4]), replace=replace, output_cscformat=False
)
# Verify in subgraph.
assert subgraph.sampled_csc["n1:e1:n2"][0].numel() == expected_sampled_num1
assert subgraph.sampled_csc["n2:e2:n1"][0].numel() == expected_sampled_num2
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_return_eids_homo(labor):
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
0 1 0 1 0
0 1 0 0 1
1 0 0 0 1
"""
# Initialize data.
total_num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
# Add edge id mapping from CSC graph -> original graph.
edge_attributes = {gb.ORIGINAL_EDGE_ID: torch.randperm(total_num_edges)}
# Construct FusedCSCSamplingGraph.
graph = gb.fused_csc_sampling_graph(
indptr, indices, edge_attributes=edge_attributes
)
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
subgraph = graph.sample_neighbors(
nodes, fanouts=torch.LongTensor([-1]), output_cscformat=False
)
# Verify in subgraph.
expected_reverse_edge_ids = edge_attributes[gb.ORIGINAL_EDGE_ID][
torch.tensor([3, 4, 7, 8, 9, 10, 11])
]
assert torch.equal(expected_reverse_edge_ids, subgraph.original_edge_ids)
assert subgraph.original_column_node_ids is None
assert subgraph.original_row_node_ids is None
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_return_eids_hetero(labor):
"""
Original graph in COO:
"n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
0 0 1 0 1
0 0 1 1 1
1 1 0 0 0
0 1 0 0 0
1 0 0 0 0
"""
# Initialize data.
ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
total_num_edges = 9
indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
node_type_offset = torch.LongTensor([0, 2, 5])
edge_attributes = {
gb.ORIGINAL_EDGE_ID: torch.cat([torch.randperm(4), torch.randperm(5)])
}
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
# Construct FusedCSCSamplingGraph.
graph = gb.fused_csc_sampling_graph(
indptr,
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
edge_attributes=edge_attributes,
)
# Sample on both node types.
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts, output_cscformat=False)
# Verify in subgraph.
expected_reverse_edge_ids = {
"n2:e2:n1": edge_attributes[gb.ORIGINAL_EDGE_ID][torch.tensor([0, 1])],
"n1:e1:n2": edge_attributes[gb.ORIGINAL_EDGE_ID][torch.tensor([4, 5])],
}
assert subgraph.original_column_node_ids is None
assert subgraph.original_row_node_ids is None
for etype in etypes.keys():
assert torch.equal(
subgraph.original_edge_ids[etype], expected_reverse_edge_ids[etype]
)
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("replace", [True, False])
@pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize("probs_name", ["weight", "mask"])
def test_sample_neighbors_probs(replace, labor, probs_name):
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
0 1 0 1 0
0 1 0 0 1
1 0 0 0 1
"""
# Initialize data.
total_num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
edge_attributes = {
"weight": torch.FloatTensor(
[2.5, 0, 8.4, 0, 0.4, 1.2, 2.5, 0, 8.4, 0.5, 0.4, 1.2]
),
"mask": torch.BoolTensor([1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1]),
}
# Construct FusedCSCSamplingGraph.
graph = gb.fused_csc_sampling_graph(
indptr, indices, edge_attributes=edge_attributes
)
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(
nodes,
fanouts=torch.tensor([2]),
replace=replace,
probs_name=probs_name,
output_cscformat=False,
)
# Verify in subgraph.
sampled_num = subgraph.sampled_csc[0].size(0)
if replace:
assert sampled_num == 6
else:
assert sampled_num == 4
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("replace", [True, False])
@pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize(
"probs_or_mask",
[
torch.zeros(12, dtype=torch.float32),
torch.zeros(12, dtype=torch.bool),
],
)
def test_sample_neighbors_zero_probs(replace, labor, probs_or_mask):
# Initialize data.
total_num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
edge_attributes = {"probs_or_mask": probs_or_mask} homo_seeds, homo_seed_timestamp = _to_homo()
# Construct FusedCSCSamplingGraph. def _get_available_neighbors():
graph = gb.fused_csc_sampling_graph( available_neighbors = []
indptr, indices, edge_attributes=edge_attributes for i, seed in enumerate(homo_seeds):
) neighbors = []
start = indptr[seed].item()
end = indptr[seed + 1].item()
for j in range(start, end):
neighbor = indices[j].item()
if (
use_node_timestamp
and (
node_timestamp[neighbor] > homo_seed_timestamp[i]
).item()
):
continue
if (
use_edge_timestamp
and (edge_timestamp[j] > homo_seed_timestamp[i]).item()
):
continue
neighbors.append(neighbor)
available_neighbors.append(neighbors)
return available_neighbors
# Generate subgraph via sample neighbors. available_neighbors = _get_available_neighbors()
nodes = torch.LongTensor([1, 3, 4]) sampled_count = [0] * homo_seeds.numel()
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors sampled_neighbors = [[] for _ in range(homo_seeds.numel())]
subgraph = sampler( for etype, csc in subgraph.sampled_csc.items():
nodes, stype, _, _ = etype_str_to_tuple(etype)
fanouts=torch.tensor([5]), ntype_offset = ntypes_to_offset[stype]
replace=replace, dest_nodes = per_etype_destination_nodes[etype]
probs_name="probs_or_mask", for i in range(dest_nodes.numel()):
output_cscformat=False, l = csc.indptr[i]
r = csc.indptr[i + 1]
seed_offset = dest_nodes[i].item()
sampled_neighbors[seed_offset].extend(
(csc.indices[l:r] + ntype_offset).tolist()
) )
sampled_count[seed_offset] += r - l
# Verify in subgraph. for i, count in enumerate(sampled_count):
sampled_num = subgraph.sampled_csc[0].size(0) assert count == len(available_neighbors[i])
assert sampled_num == 0 assert set(sampled_neighbors[i]).issubset(set(available_neighbors[i]))
def check_tensors_on_the_same_shared_memory(t1: torch.Tensor, t2: torch.Tensor):
@@ -2158,199 +1511,6 @@ def test_from_dglgraph_heterogeneous():
}

@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("replace", [False, True])
@pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize(
"fanouts, probs_name",
[
([2], "mask"),
([3], "mask"),
([4], "mask"),
([-1], "mask"),
([7], "mask"),
([3], "all"),
([-1], "all"),
([7], "all"),
([3], "zero"),
([-1], "zero"),
([3], "none"),
([-1], "none"),
],
)
def test_sample_neighbors_homo_pick_number(fanouts, replace, labor, probs_name):
"""Original graph in COO:
1 1 1 1 1 1
0 0 0 0 0 0
0 0 0 0 0 0
0 0 0 0 0 0
0 0 0 0 0 0
0 0 0 0 0 0
"""
# Initialize data.
total_num_nodes = 6
total_num_edges = 6
indptr = torch.LongTensor([0, 6, 6, 6, 6, 6, 6])
indices = torch.LongTensor([0, 1, 2, 3, 4, 5])
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
edge_attributes = {
"mask": torch.BoolTensor([1, 0, 0, 1, 0, 1]),
"all": torch.BoolTensor([1, 1, 1, 1, 1, 1]),
"zero": torch.BoolTensor([0, 0, 0, 0, 0, 0]),
}
# Construct FusedCSCSamplingGraph.
graph = gb.fused_csc_sampling_graph(
indptr, indices, edge_attributes=edge_attributes
)
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([0, 1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
# Make sure no exception will be thrown.
subgraph = sampler(
nodes,
fanouts=torch.LongTensor(fanouts),
replace=replace,
probs_name=probs_name if probs_name != "none" else None,
output_cscformat=False,
)
sampled_num = subgraph.sampled_csc[0].size(0)
# Verify in subgraph.
if probs_name == "mask":
if fanouts[0] == -1:
assert sampled_num == 3
else:
if replace:
assert sampled_num == fanouts[0]
else:
assert sampled_num == min(fanouts[0], 3)
elif probs_name == "zero":
assert sampled_num == 0
else:
if fanouts[0] == -1:
assert sampled_num == 6
else:
if replace:
assert sampled_num == fanouts[0]
else:
assert sampled_num == min(fanouts[0], 6)
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("replace", [False, True])
@pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize(
"fanouts, probs_name",
[
([-1, -1, -1], "mask"),
([1, 1, 1], "mask"),
([2, 2, 2], "mask"),
([3, 3, 3], "mask"),
([4, 4, 4], "mask"),
([-1, 1, 3], "none"),
([2, -1, 4], "none"),
],
)
def test_sample_neighbors_hetero_pick_number(
fanouts, replace, labor, probs_name
):
# Initialize data.
total_num_nodes = 10
total_num_edges = 9
ntypes = {"N0": 0, "N1": 1, "N2": 2, "N3": 3}
etypes = {
"N0:R0:N1": 0,
"N0:R1:N2": 1,
"N0:R2:N3": 2,
}
indptr = torch.LongTensor([0, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9])
indices = torch.LongTensor([1, 2, 3, 4, 5, 6, 7, 8, 9])
node_type_offset = torch.LongTensor([0, 1, 4, 7, 10])
type_per_edge = torch.LongTensor([0, 0, 0, 1, 1, 1, 2, 2, 2])
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
assert node_type_offset[-1] == total_num_nodes
assert all(type_per_edge < len(etypes))
edge_attributes = {
"mask": torch.BoolTensor([1, 1, 0, 1, 1, 1, 0, 0, 0]),
"all": torch.BoolTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]),
"zero": torch.BoolTensor([0, 0, 0, 0, 0, 0, 0, 0, 0]),
}
# Construct FusedCSCSamplingGraph.
graph = gb.fused_csc_sampling_graph(
indptr,
indices,
edge_attributes=edge_attributes,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([0, 1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
# Make sure no exception will be thrown.
subgraph = sampler(
nodes,
fanouts=torch.LongTensor(fanouts),
replace=replace,
probs_name=probs_name if probs_name != "none" else None,
output_cscformat=False,
)
if probs_name == "none":
for etype, pairs in subgraph.sampled_csc.items():
fanout = fanouts[etypes[etype]]
if fanout == -1:
assert pairs[0].size(0) == 3
else:
if replace:
assert pairs[0].size(0) == fanout
else:
assert pairs[0].size(0) == min(fanout, 3)
else:
fanout = fanouts[0] # Here fanout is the same for all etypes.
for etype, pairs in subgraph.sampled_csc.items():
if etypes[etype] == 0:
# Etype 0: 2 valid neighbors.
if fanout == -1:
assert pairs[0].size(0) == 2
else:
if replace:
assert pairs[0].size(0) == fanout
else:
assert pairs[0].size(0) == min(fanout, 2)
elif etypes[etype] == 1:
# Etype 1: 3 valid neighbors.
if fanout == -1:
assert pairs[0].size(0) == 3
else:
if replace:
assert pairs[0].size(0) == fanout
else:
assert pairs[0].size(0) == min(fanout, 3)
else:
# Etype 2: 0 valid neighbors.
assert pairs[0].size(0) == 0
@unittest.skipIf(
F._default_context_str == "cpu",
reason="`to` function needs GPU to test.",
@@ -2408,7 +1568,7 @@ def test_csc_sampling_graph_to_device():
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
def test_sample_neighbors_homo_csc_format():
def test_sample_neighbors_homo():
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
@@ -2428,9 +1588,7 @@ def test_sample_neighbors_homo_csc_format():
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
subgraph = graph.sample_neighbors(
nodes, fanouts=torch.LongTensor([2]), output_cscformat=True
)
subgraph = graph.sample_neighbors(nodes, fanouts=torch.LongTensor([2]))

# Verify in subgraph.
sampled_indptr_num = subgraph.sampled_csc.indptr.size(0)
@@ -2447,7 +1605,7 @@ def test_sample_neighbors_homo_csc_format():
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_hetero_csc_format(labor):
def test_sample_neighbors_hetero(labor):
"""Original graph in COO:
"n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
@@ -2482,7 +1640,7 @@ def test_sample_neighbors_hetero_csc_format(labor):
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts, output_cscformat=True)
subgraph = sampler(nodes, fanouts)

# Verify in subgraph.
expected_sampled_csc = {
@@ -2507,7 +1665,7 @@ def test_sample_neighbors_hetero_csc_format(labor):
nodes = {"n1": torch.LongTensor([0])}
fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts, output_cscformat=True)
subgraph = sampler(nodes, fanouts)

# Verify in subgraph.
expected_sampled_csc = {
@@ -2550,7 +1708,7 @@ def test_sample_neighbors_hetero_csc_format(labor):
],
)
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_fanouts_csc_format(
def test_sample_neighbors_fanouts(
fanouts, expected_sampled_num1, expected_sampled_num2, labor
):
"""Original graph in COO:
@@ -2586,7 +1744,7 @@ def test_sample_neighbors_fanouts_csc_format(
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
fanouts = torch.LongTensor(fanouts)
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts, output_cscformat=True)
subgraph = sampler(nodes, fanouts)

# Verify in subgraph.
assert (
@@ -2611,7 +1769,7 @@ def test_sample_neighbors_fanouts_csc_format(
"replace, expected_sampled_num1, expected_sampled_num2",
[(False, 2, 2), (True, 4, 4)],
)
def test_sample_neighbors_replace_csc_format(
def test_sample_neighbors_replace(
replace, expected_sampled_num1, expected_sampled_num2
):
"""Original graph in COO:
@@ -2646,7 +1804,7 @@ def test_sample_neighbors_replace_csc_format(
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
subgraph = graph.sample_neighbors(
nodes, torch.LongTensor([4]), replace=replace, output_cscformat=True
nodes, torch.LongTensor([4]), replace=replace
)

# Verify in subgraph.
@@ -2667,7 +1825,7 @@ def test_sample_neighbors_replace_csc_format(
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_return_eids_homo_csc_format(labor):
def test_sample_neighbors_return_eids_homo(labor):
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
@@ -2693,9 +1851,7 @@ def test_sample_neighbors_return_eids_homo_csc_format(labor):
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(
nodes, fanouts=torch.LongTensor([-1]), output_cscformat=True
)
subgraph = sampler(nodes, fanouts=torch.LongTensor([-1]))

# Verify in subgraph.
expected_reverse_edge_ids = edge_attributes[gb.ORIGINAL_EDGE_ID][
@@ -2711,7 +1867,7 @@ def test_sample_neighbors_return_eids_homo_csc_format(labor):
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_return_eids_hetero_csc_format(labor):
def test_sample_neighbors_return_eids_hetero(labor):
"""
Original graph in COO:
"n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
@@ -2751,7 +1907,7 @@ def test_sample_neighbors_return_eids_hetero_csc_format(labor):
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts, output_cscformat=True)
subgraph = sampler(nodes, fanouts)

# Verify in subgraph.
expected_reverse_edge_ids = {
@@ -2773,7 +1929,7 @@ def test_sample_neighbors_return_eids_hetero_csc_format(labor):
@pytest.mark.parametrize("replace", [True, False])
@pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize("probs_name", ["weight", "mask"])
def test_sample_neighbors_probs_csc_format(replace, labor, probs_name):
def test_sample_neighbors_probs(replace, labor, probs_name):
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
@@ -2809,7 +1965,6 @@ def test_sample_neighbors_probs_csc_format(replace, labor, probs_name):
fanouts=torch.tensor([2]),
replace=replace,
probs_name=probs_name,
output_cscformat=True,
)

# Verify in subgraph.
@@ -2834,7 +1989,7 @@ def test_sample_neighbors_probs_csc_format(replace, labor, probs_name):
torch.zeros(12, dtype=torch.bool),
],
)
def test_sample_neighbors_zero_probs_csc_format(replace, labor, probs_or_mask):
def test_sample_neighbors_zero_probs(replace, labor, probs_or_mask):
# Initialize data.
total_num_nodes = 5
total_num_edges = 12
@@ -2858,7 +2013,6 @@ def test_sample_neighbors_zero_probs_csc_format(replace, labor, probs_or_mask):
fanouts=torch.tensor([5]),
replace=replace,
probs_name="probs_or_mask",
output_cscformat=True,
)

# Verify in subgraph.
@@ -2890,9 +2044,7 @@ def test_sample_neighbors_zero_probs_csc_format(replace, labor, probs_or_mask):
([-1], "none"),
],
)
def test_sample_neighbors_homo_pick_number_csc_format(
fanouts, replace, labor, probs_name
):
def test_sample_neighbors_homo_pick_number(fanouts, replace, labor, probs_name):
"""Original graph in COO:
1 1 1 1 1 1
0 0 0 0 0 0
@@ -2930,7 +2082,6 @@ def test_sample_neighbors_homo_pick_number_csc_format(
fanouts=torch.LongTensor(fanouts),
replace=replace,
probs_name=probs_name if probs_name != "none" else None,
output_cscformat=True,
)
sampled_num = subgraph.sampled_csc.indices.size(0)
assert subgraph.sampled_csc.indptr.size(0) == 3
@@ -2973,7 +2124,7 @@ def test_sample_neighbors_homo_pick_number_csc_format(
([2, -1, 4], "none"),
],
)
def test_sample_neighbors_hetero_pick_number_csc_format(
def test_sample_neighbors_hetero_pick_number(
fanouts, replace, labor, probs_name
):
# Initialize data.
@@ -3022,7 +2173,6 @@ def test_sample_neighbors_hetero_pick_number_csc_format(
fanouts=torch.LongTensor(fanouts),
replace=replace,
probs_name=probs_name if probs_name != "none" else None,
output_cscformat=True,
)
print(subgraph)
if probs_name == "none":
......
@@ -1143,9 +1143,8 @@ def test_OnDiskDataset_preprocess_homogeneous():
subgraph = fused_csc_sampling_graph.sample_neighbors(
torch.arange(num_samples),
torch.tensor([fanout]),
output_cscformat=False,
)
assert len(subgraph.sampled_csc[0]) <= num_samples
assert len(subgraph.sampled_csc.indices) <= num_samples
with tempfile.TemporaryDirectory() as test_dir:
# All metadata fields are specified.
......
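
For heterogeneous sampling the same access pattern applies per edge type, as the temporal test above does; a short illustrative sketch (variable names follow the tests and are assumptions, not part of the commit):

    # `subgraph` is the SampledSubgraphImpl returned by a heterogeneous sampling call.
    for etype, csc in subgraph.sampled_csc.items():
        # Each value is a CSCFormatBase: `indptr` has one entry per destination
        # seed plus one, `indices` holds sampled source ids within the source type.
        print(etype, csc.indptr.tolist(), csc.indices.tolist())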