"vscode:/vscode.git/clone" did not exist on "46d7b1d9723a64eecc676ef2c629f8c0226c1b68"
Unverified commit fa1ae3b7 authored by yxy235, committed by GitHub

[GraphBolt] Remove coo from `FusedCSCSamplingGraph`. (#6849)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent f81e514d
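For orientation, a minimal sketch of what this change means for callers (not taken from the diff; the graph data is borrowed from the tests further down): the `output_cscformat` flag disappears from `in_subgraph`, `sample_neighbors`, and `sample_layer_neighbors`, and sampling always returns a `SampledSubgraphImpl` whose `sampled_csc` holds CSC-format tensors.

```python
import torch
import dgl.graphbolt as gb

# Tiny homogeneous graph in CSC form (same indptr/indices as the tests below).
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
graph = gb.fused_csc_sampling_graph(indptr, indices)

nodes = torch.LongTensor([1, 3, 4])
# Before this change a caller could opt into COO-style output:
#   graph.sample_neighbors(nodes, torch.LongTensor([2]), output_cscformat=False)
# Now the flag is gone and the result is always CSC-format.
subgraph = graph.sample_neighbors(nodes, fanouts=torch.LongTensor([2]))
print(subgraph.sampled_csc.indptr, subgraph.sampled_csc.indices)
```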
......@@ -12,11 +12,7 @@ from ...convert import to_homogeneous
from ...heterograph import DGLGraph
from ..base import etype_str_to_tuple, etype_tuple_to_str, ORIGINAL_EDGE_ID
from ..sampling_graph import SamplingGraph
from .sampled_subgraph_impl import (
CSCFormatBase,
FusedSampledSubgraphImpl,
SampledSubgraphImpl,
)
from .sampled_subgraph_impl import CSCFormatBase, SampledSubgraphImpl
__all__ = [
......@@ -324,9 +320,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
def in_subgraph(
self,
nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
# TODO: clean up once the migration is done.
output_cscformat=True,
) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
) -> SampledSubgraphImpl:
"""Return the subgraph induced on the inbound edges of the given nodes.
An in subgraph is equivalent to creating a new graph using the incoming
......@@ -344,7 +338,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
Returns
-------
FusedSampledSubgraphImpl
SampledSubgraphImpl
The in subgraph.
Examples
......@@ -389,60 +383,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
), "Nodes cannot have duplicate values."
_in_subgraph = self._c_csc_graph.in_subgraph(nodes)
if not output_cscformat:
return self._convert_to_fused_sampled_subgraph(_in_subgraph)
else:
return self._convert_to_sampled_subgraph(_in_subgraph)
def _convert_to_fused_sampled_subgraph(
self,
C_sampled_subgraph: torch.ScriptObject,
):
"""An internal function used to convert a fused homogeneous sampled
subgraph to general struct 'FusedSampledSubgraphImpl'."""
column_num = (
C_sampled_subgraph.indptr[1:] - C_sampled_subgraph.indptr[:-1]
)
column = C_sampled_subgraph.original_column_node_ids.repeat_interleave(
column_num
)
row = C_sampled_subgraph.indices
type_per_edge = C_sampled_subgraph.type_per_edge
original_edge_ids = C_sampled_subgraph.original_edge_ids
has_original_eids = (
self.edge_attributes is not None
and ORIGINAL_EDGE_ID in self.edge_attributes
)
if has_original_eids:
original_edge_ids = self.edge_attributes[ORIGINAL_EDGE_ID][
original_edge_ids
]
if type_per_edge is None:
# The sampled graph is already a homogeneous graph.
node_pairs = (row, column)
else:
# The sampled graph is a fused homogenized graph, which need to be
# converted to heterogeneous graphs.
node_pairs = defaultdict(list)
original_hetero_edge_ids = {}
for etype, etype_id in self.edge_type_to_id.items():
src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
src_ntype_id = self.node_type_to_id[src_ntype]
dst_ntype_id = self.node_type_to_id[dst_ntype]
mask = type_per_edge == etype_id
hetero_row = row[mask] - self.node_type_offset[src_ntype_id]
hetero_column = (
column[mask] - self.node_type_offset[dst_ntype_id]
)
node_pairs[etype] = (hetero_row, hetero_column)
if has_original_eids:
original_hetero_edge_ids[etype] = original_edge_ids[mask]
if has_original_eids:
original_edge_ids = original_hetero_edge_ids
return FusedSampledSubgraphImpl(
sampled_csc=node_pairs, original_edge_ids=original_edge_ids
)
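# For reference, a minimal sketch (not part of this diff, helper name illustrative) of
# the CSC-to-COO expansion the removed helper performed, assuming plain torch tensors:
def coo_from_csc_sketch(indptr, indices, column_ids):
    # Column i owns indptr[i + 1] - indptr[i] incoming edges, so repeating each column
    # id by its in-degree recovers the COO (row, column) node pairs.
    column = column_ids.repeat_interleave(indptr[1:] - indptr[:-1])
    return indices, column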
def _convert_to_homogeneous_nodes(self, nodes, timestamps=None):
homogeneous_nodes = []
homogeneous_timestamps = []
......@@ -560,9 +502,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
fanouts: torch.Tensor,
replace: bool = False,
probs_name: Optional[str] = None,
# TODO: clean up once the migration is done.
output_cscformat=True,
) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
) -> SampledSubgraphImpl:
"""Sample neighboring edges of the given nodes and return the induced
subgraph.
......@@ -604,7 +544,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
Returns
-------
Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]
SampledSubgraphImpl
The sampled subgraph.
Examples
......@@ -638,9 +578,6 @@ class FusedCSCSamplingGraph(SamplingGraph):
C_sampled_subgraph = self._sample_neighbors(
nodes, fanouts, replace, probs_name
)
if not output_cscformat:
return self._convert_to_fused_sampled_subgraph(C_sampled_subgraph)
else:
return self._convert_to_sampled_subgraph(C_sampled_subgraph)
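# Illustrative note (not part of the diff): for heterogeneous input such as
#   graph.sample_neighbors(
#       {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])},
#       torch.LongTensor([-1, -1]),
#   )
# the returned SampledSubgraphImpl keeps sampled_csc as a dict mapping edge-type strings
# like "n1:e1:n2" to CSCFormatBase objects with indptr and indices tensors.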
def _check_sampler_arguments(self, nodes, fanouts, probs_name):
......@@ -753,9 +690,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
fanouts: torch.Tensor,
replace: bool = False,
probs_name: Optional[str] = None,
# TODO: clean up once the migration is done.
output_cscformat=True,
) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
) -> SampledSubgraphImpl:
"""Sample neighboring edges of the given nodes and return the induced
subgraph via layer-neighbor sampling from the NeurIPS 2023 paper
`Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs
......@@ -799,7 +734,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
Returns
-------
FusedSampledSubgraphImpl
SampledSubgraphImpl
The sampled subgraph.
Examples
......@@ -843,10 +778,6 @@ class FusedCSCSamplingGraph(SamplingGraph):
has_original_eids,
probs_name,
)
if not output_cscformat:
return self._convert_to_fused_sampled_subgraph(C_sampled_subgraph)
else:
return self._convert_to_sampled_subgraph(C_sampled_subgraph)
def temporal_sample_neighbors(
......@@ -906,7 +837,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
Returns
-------
FusedSampledSubgraphImpl
SampledSubgraphImpl
The sampled subgraph.
"""
if isinstance(nodes, dict):
......@@ -683,137 +683,6 @@ def test_multiprocessing():
p.join()
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
def test_in_subgraph_node_pairs_homogeneous():
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
0 1 0 1 0
0 1 0 0 1
1 0 0 0 1
"""
# Initialize data.
total_num_nodes = 5
total_num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
# Construct FusedCSCSamplingGraph.
graph = gb.fused_csc_sampling_graph(indptr, indices)
# Extract in subgraph.
nodes = torch.LongTensor([4, 1, 3])
in_subgraph = graph.in_subgraph(nodes, output_cscformat=False)
# Verify in subgraph.
assert torch.equal(
in_subgraph.sampled_csc[0], torch.LongTensor([0, 3, 4, 2, 3, 1, 2])
)
assert torch.equal(
in_subgraph.sampled_csc[1], torch.LongTensor([4, 4, 4, 1, 1, 3, 3])
)
assert in_subgraph.original_column_node_ids is None
assert in_subgraph.original_row_node_ids is None
assert torch.equal(
in_subgraph.original_edge_ids, torch.LongTensor([9, 10, 11, 3, 4, 7, 8])
)
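# For comparison (derived from the tensors above, not part of the diff): in CSC form,
# keeping the columns in seed order [4, 1, 3], the same in-subgraph is expressed as
# indptr [0, 3, 5, 7] with indices [0, 3, 4, 2, 3, 1, 2].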
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
def test_in_subgraph_node_pairs_heterogeneous():
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
0 1 0 1 0
0 1 0 0 1
1 0 0 0 1
node_type_0: [0, 1]
node_type_1: [2, 3, 4]
edge_type_0: node_type_0 -> node_type_0
edge_type_1: node_type_0 -> node_type_1
edge_type_2: node_type_1 -> node_type_0
edge_type_3: node_type_1 -> node_type_1
"""
# Initialize data.
total_num_nodes = 5
total_num_edges = 12
ntypes = {
"N0": 0,
"N1": 1,
}
etypes = {
"N0:R0:N0": 0,
"N0:R1:N1": 1,
"N1:R2:N0": 2,
"N1:R3:N1": 3,
}
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
node_type_offset = torch.LongTensor([0, 2, 5])
type_per_edge = torch.LongTensor([0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
assert node_type_offset[-1] == total_num_nodes
assert all(type_per_edge < len(etypes))
# Construct FusedCSCSamplingGraph.
graph = gb.fused_csc_sampling_graph(
indptr,
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
# Extract in subgraph.
nodes = {
"N0": torch.LongTensor([1]),
"N1": torch.LongTensor([2, 1]),
}
in_subgraph = graph.in_subgraph(nodes, output_cscformat=False)
# Verify in subgraph.
assert torch.equal(
in_subgraph.sampled_csc["N0:R0:N0"][0], torch.LongTensor([])
)
assert torch.equal(
in_subgraph.sampled_csc["N0:R0:N0"][1], torch.LongTensor([])
)
assert torch.equal(
in_subgraph.sampled_csc["N0:R1:N1"][0], torch.LongTensor([0, 1])
)
assert torch.equal(
in_subgraph.sampled_csc["N0:R1:N1"][1], torch.LongTensor([2, 1])
)
assert torch.equal(
in_subgraph.sampled_csc["N1:R2:N0"][0], torch.LongTensor([0, 1])
)
assert torch.equal(
in_subgraph.sampled_csc["N1:R2:N0"][1], torch.LongTensor([1, 1])
)
assert torch.equal(
in_subgraph.sampled_csc["N1:R3:N1"][0], torch.LongTensor([1, 2, 0])
)
assert torch.equal(
in_subgraph.sampled_csc["N1:R3:N1"][1], torch.LongTensor([2, 2, 1])
)
assert in_subgraph.original_column_node_ids is None
assert in_subgraph.original_row_node_ids is None
assert torch.equal(
in_subgraph.original_edge_ids, torch.LongTensor([3, 4, 9, 10, 11, 7, 8])
)
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
......@@ -839,7 +708,7 @@ def test_in_subgraph_homo():
# Extract in subgraph.
nodes = torch.LongTensor([4, 1, 3])
in_subgraph = graph.in_subgraph(nodes, output_cscformat=True)
in_subgraph = graph.in_subgraph(nodes)
# Verify in subgraph.
assert torch.equal(
......@@ -911,7 +780,7 @@ def test_in_subgraph_hetero():
"N0": torch.LongTensor([1]),
"N1": torch.LongTensor([2, 1]),
}
in_subgraph = graph.in_subgraph(nodes, output_cscformat=True)
in_subgraph = graph.in_subgraph(nodes)
# Verify in subgraph.
assert torch.equal(
......@@ -945,64 +814,6 @@ def test_in_subgraph_hetero():
)
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
def test_sample_neighbors_homo(labor, indptr_dtype, indices_dtype):
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
0 1 0 1 0
0 1 0 0 1
1 0 0 0 1
"""
# Initialize data.
total_num_nodes = 5
total_num_edges = 12
indptr = torch.tensor([0, 3, 5, 7, 9, 12], dtype=indptr_dtype)
indices = torch.tensor(
[0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4], dtype=indices_dtype
)
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
assert len(indptr) == total_num_nodes + 1
# Construct FusedCSCSamplingGraph.
graph = gb.fused_csc_sampling_graph(indptr, indices)
# Generate subgraph via sample neighbors.
fanouts = torch.LongTensor([2])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
# 1. Sample with nodes in mismatched dtype with graph's indices.
nodes = torch.tensor(
[1, 3, 4],
dtype=(torch.int64 if indices_dtype == torch.int32 else torch.int32),
)
with pytest.raises(
AssertionError,
match=re.escape(
"Data type of nodes must be consistent with indices.dtype"
),
):
_ = sampler(nodes, fanouts)
# 2. Sample with nodes in matched dtype with graph's indices.
nodes = torch.tensor([1, 3, 4], dtype=indices_dtype)
subgraph = sampler(nodes, fanouts, output_cscformat=False)
# Verify in subgraph.
sampled_num = subgraph.sampled_csc[0].size(0)
assert sampled_num == 6
assert subgraph.original_column_node_ids is None
assert subgraph.original_row_node_ids is None
assert subgraph.original_edge_ids is None
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
......@@ -1125,10 +936,14 @@ def test_temporal_sample_neighbors_homo(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
@pytest.mark.parametrize("replace", [False, True])
@pytest.mark.parametrize("use_node_timestamp", [False, True])
@pytest.mark.parametrize("use_edge_timestamp", [False, True])
def test_temporal_sample_neighbors_hetero(
indptr_dtype, indices_dtype, replace, use_node_timestamp, use_edge_timestamp
):
"""Original graph in COO:
"n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
......@@ -1141,6 +956,8 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
# Initialize data.
ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
ntypes_to_offset = {"n1": 0, "n2": 2}
total_num_nodes = 5
total_num_edges = 9
indptr = torch.tensor([0, 2, 4, 6, 7, 9], dtype=indptr_dtype)
indices = torch.tensor([2, 4, 2, 3, 0, 1, 1, 0, 1], dtype=indices_dtype)
......@@ -1159,561 +976,97 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
edge_type_to_id=etypes,
)
# Sample on both node types.
fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
# 1. Sample with nodes in mismatched dtype with graph's indices.
nodes = {
"n1": torch.tensor(
[0],
dtype=(
torch.int64 if indices_dtype == torch.int32 else torch.int32
),
),
"n2": torch.tensor(
[0],
dtype=(
torch.int64 if indices_dtype == torch.int32 else torch.int32
),
),
}
with pytest.raises(
AssertionError,
match=re.escape(
"Data type of nodes must be consistent with indices.dtype"
),
):
_ = sampler(nodes, fanouts, output_cscformat=False)
# Generate subgraph via sample neighbors.
fanouts = torch.LongTensor([-1, -1])
sampler = graph.temporal_sample_neighbors
# 2. Sample with nodes in matched dtype with graph's indices.
nodes = {
seeds = {
"n1": torch.tensor([0], dtype=indices_dtype),
"n2": torch.tensor([0], dtype=indices_dtype),
}
subgraph = sampler(nodes, fanouts, output_cscformat=False)
# Verify in subgraph.
expected_node_pairs = {
"n1:e1:n2": (
torch.LongTensor([0, 1]),
torch.LongTensor([0, 0]),
),
"n2:e2:n1": (
torch.LongTensor([0, 2]),
torch.LongTensor([0, 0]),
),
}
assert len(subgraph.sampled_csc) == 2
for etype, pairs in expected_node_pairs.items():
assert torch.equal(subgraph.sampled_csc[etype][0], pairs[0])
assert torch.equal(subgraph.sampled_csc[etype][1], pairs[1])
assert subgraph.original_column_node_ids is None
assert subgraph.original_row_node_ids is None
assert subgraph.original_edge_ids is None
# Sample on single node type.
fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
# 1. Sample with nodes in mismatched dtype with graph's indices.
nodes = {
"n1": torch.tensor(
[0],
dtype=(
torch.int64 if indices_dtype == torch.int32 else torch.int32
),
)
per_etype_destination_nodes = {
"n1:e1:n2": torch.tensor([1], dtype=indices_dtype),
"n2:e2:n1": torch.tensor([0], dtype=indices_dtype),
}
with pytest.raises(
AssertionError,
match=re.escape(
"Data type of nodes must be consistent with indices.dtype"
),
):
_ = sampler(nodes, fanouts, output_cscformat=False)
# 2. Sample with nodes in matched dtype with graph's indices.
nodes = {"n1": torch.tensor([0], dtype=indices_dtype)}
subgraph = sampler(nodes, fanouts, output_cscformat=False)
# Verify in subgraph.
expected_node_pairs = {
"n2:e2:n1": (
torch.tensor([0, 2], dtype=indices_dtype),
torch.tensor([0, 0], dtype=indices_dtype),
),
"n1:e1:n2": (
torch.tensor([], dtype=indices_dtype),
torch.tensor([], dtype=indices_dtype),
),
seed_timestamp = {
"n1": torch.randint(0, 100, (1,), dtype=torch.int64),
"n2": torch.randint(0, 100, (1,), dtype=torch.int64),
}
assert len(subgraph.sampled_csc) == 2
for etype, pairs in expected_node_pairs.items():
assert torch.equal(subgraph.sampled_csc[etype][0], pairs[0])
assert torch.equal(subgraph.sampled_csc[etype][1], pairs[1])
assert subgraph.original_column_node_ids is None
assert subgraph.original_row_node_ids is None
assert subgraph.original_edge_ids is None
if use_node_timestamp:
node_timestamp = torch.randint(
0, 100, (total_num_nodes,), dtype=torch.int64
)
graph.node_attributes = {"timestamp": node_timestamp}
if use_edge_timestamp:
edge_timestamp = torch.randint(
0, 100, (total_num_edges,), dtype=torch.int64
)
graph.edge_attributes = {"timestamp": edge_timestamp}
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("replace", [False, True])
@pytest.mark.parametrize("use_node_timestamp", [False, True])
@pytest.mark.parametrize("use_edge_timestamp", [False, True])
def test_temporal_sample_neighbors_hetero(
indptr_dtype, indices_dtype, replace, use_node_timestamp, use_edge_timestamp
):
"""Original graph in COO:
"n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
0 0 1 0 1
0 0 1 1 1
1 1 0 0 0
0 1 0 0 0
1 0 0 0 0
"""
# Initialize data.
ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
ntypes_to_offset = {"n1": 0, "n2": 2}
total_num_nodes = 5
total_num_edges = 9
indptr = torch.tensor([0, 2, 4, 6, 7, 9], dtype=indptr_dtype)
indices = torch.tensor([2, 4, 2, 3, 0, 1, 1, 0, 1], dtype=indices_dtype)
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
node_type_offset = torch.LongTensor([0, 2, 5])
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
# Construct FusedCSCSamplingGraph.
graph = gb.fused_csc_sampling_graph(
indptr,
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
# Generate subgraph via sample neighbors.
fanouts = torch.LongTensor([-1, -1])
sampler = graph.temporal_sample_neighbors
seeds = {
"n1": torch.tensor([0], dtype=indices_dtype),
"n2": torch.tensor([0], dtype=indices_dtype),
}
per_etype_destination_nodes = {
"n1:e1:n2": torch.tensor([1], dtype=indices_dtype),
"n2:e2:n1": torch.tensor([0], dtype=indices_dtype),
}
seed_timestamp = {
"n1": torch.randint(0, 100, (1,), dtype=torch.int64),
"n2": torch.randint(0, 100, (1,), dtype=torch.int64),
}
if use_node_timestamp:
node_timestamp = torch.randint(
0, 100, (total_num_nodes,), dtype=torch.int64
)
graph.node_attributes = {"timestamp": node_timestamp}
if use_edge_timestamp:
edge_timestamp = torch.randint(
0, 100, (total_num_edges,), dtype=torch.int64
)
graph.edge_attributes = {"timestamp": edge_timestamp}
subgraph, neighbors_timestamp = sampler(
seeds,
seed_timestamp,
fanouts,
replace=replace,
node_timestamp_attr_name="timestamp" if use_node_timestamp else None,
edge_timestamp_attr_name="timestamp" if use_edge_timestamp else None,
)
subgraph, neighbors_timestamp = sampler(
seeds,
seed_timestamp,
fanouts,
replace=replace,
node_timestamp_attr_name="timestamp" if use_node_timestamp else None,
edge_timestamp_attr_name="timestamp" if use_edge_timestamp else None,
)
def _to_homo():
ret_seeds, ret_timestamps = [], []
for ntype, nodes in seeds.items():
ntype_id = ntypes[ntype]
offset = node_type_offset[ntype_id]
ret_seeds.append(nodes + offset)
ret_timestamps.append(seed_timestamp[ntype])
return torch.cat(ret_seeds), torch.cat(ret_timestamps)
homo_seeds, homo_seed_timestamp = _to_homo()
def _get_available_neighbors():
available_neighbors = []
for i, seed in enumerate(homo_seeds):
neighbors = []
start = indptr[seed].item()
end = indptr[seed + 1].item()
for j in range(start, end):
neighbor = indices[j].item()
if (
use_node_timestamp
and (
node_timestamp[neighbor] > homo_seed_timestamp[i]
).item()
):
continue
if (
use_edge_timestamp
and (edge_timestamp[j] > homo_seed_timestamp[i]).item()
):
continue
neighbors.append(neighbor)
available_neighbors.append(neighbors)
return available_neighbors
available_neighbors = _get_available_neighbors()
sampled_count = [0] * homo_seeds.numel()
sampled_neighbors = [[] for _ in range(homo_seeds.numel())]
for etype, csc in subgraph.sampled_csc.items():
stype, _, _ = etype_str_to_tuple(etype)
ntype_offset = ntypes_to_offset[stype]
dest_nodes = per_etype_destination_nodes[etype]
for i in range(dest_nodes.numel()):
l = csc.indptr[i]
r = csc.indptr[i + 1]
seed_offset = dest_nodes[i].item()
sampled_neighbors[seed_offset].extend(
(csc.indices[l:r] + ntype_offset).tolist()
)
sampled_count[seed_offset] += r - l
for i, count in enumerate(sampled_count):
assert count == len(available_neighbors[i])
assert set(sampled_neighbors[i]).issubset(set(available_neighbors[i]))
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize(
"fanouts, expected_sampled_num1, expected_sampled_num2",
[
([0], 0, 0),
([1], 1, 1),
([2], 2, 2),
([4], 2, 2),
([-1], 2, 2),
([0, 0], 0, 0),
([1, 0], 1, 0),
([0, 1], 0, 1),
([1, 1], 1, 1),
([2, 1], 2, 1),
([-1, -1], 2, 2),
],
)
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_fanouts(
fanouts, expected_sampled_num1, expected_sampled_num2, labor
):
"""Original graph in COO:
"n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
0 0 1 0 1
0 0 1 1 1
1 1 0 0 0
0 1 0 0 0
1 0 0 0 0
"""
# Initialize data.
ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
total_num_edges = 9
indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
node_type_offset = torch.LongTensor([0, 2, 5])
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
# Construct FusedCSCSamplingGraph.
graph = gb.fused_csc_sampling_graph(
indptr,
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
fanouts = torch.LongTensor(fanouts)
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts, output_cscformat=False)
# Verify in subgraph.
assert (
expected_sampled_num1 == 0
or subgraph.sampled_csc["n1:e1:n2"][0].numel() == expected_sampled_num1
)
assert (
expected_sampled_num2 == 0
or subgraph.sampled_csc["n2:e2:n1"][0].numel() == expected_sampled_num2
)
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize(
"replace, expected_sampled_num1, expected_sampled_num2",
[(False, 2, 2), (True, 4, 4)],
)
def test_sample_neighbors_replace(
replace, expected_sampled_num1, expected_sampled_num2
):
"""Original graph in COO:
"n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
0 0 1 0 1
0 0 1 1 1
1 1 0 0 0
0 1 0 0 0
1 0 0 0 0
"""
# Initialize data.
ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
total_num_edges = 9
indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
node_type_offset = torch.LongTensor([0, 2, 5])
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
# Construct FusedCSCSamplingGraph.
graph = gb.fused_csc_sampling_graph(
indptr,
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
subgraph = graph.sample_neighbors(
nodes, torch.LongTensor([4]), replace=replace, output_cscformat=False
)
# Verify in subgraph.
assert subgraph.sampled_csc["n1:e1:n2"][0].numel() == expected_sampled_num1
assert subgraph.sampled_csc["n2:e2:n1"][0].numel() == expected_sampled_num2
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_return_eids_homo(labor):
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
0 1 0 1 0
0 1 0 0 1
1 0 0 0 1
"""
# Initialize data.
total_num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
# Add edge id mapping from CSC graph -> original graph.
edge_attributes = {gb.ORIGINAL_EDGE_ID: torch.randperm(total_num_edges)}
# Construct FusedCSCSamplingGraph.
graph = gb.fused_csc_sampling_graph(
indptr, indices, edge_attributes=edge_attributes
)
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
subgraph = graph.sample_neighbors(
nodes, fanouts=torch.LongTensor([-1]), output_cscformat=False
)
# Verify in subgraph.
expected_reverse_edge_ids = edge_attributes[gb.ORIGINAL_EDGE_ID][
torch.tensor([3, 4, 7, 8, 9, 10, 11])
]
assert torch.equal(expected_reverse_edge_ids, subgraph.original_edge_ids)
assert subgraph.original_column_node_ids is None
assert subgraph.original_row_node_ids is None
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_return_eids_hetero(labor):
"""
Original graph in COO:
"n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
0 0 1 0 1
0 0 1 1 1
1 1 0 0 0
0 1 0 0 0
1 0 0 0 0
"""
# Initialize data.
ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
total_num_edges = 9
indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
node_type_offset = torch.LongTensor([0, 2, 5])
edge_attributes = {
gb.ORIGINAL_EDGE_ID: torch.cat([torch.randperm(4), torch.randperm(5)])
}
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
# Construct FusedCSCSamplingGraph.
graph = gb.fused_csc_sampling_graph(
indptr,
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
edge_attributes=edge_attributes,
)
# Sample on both node types.
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts, output_cscformat=False)
# Verify in subgraph.
expected_reverse_edge_ids = {
"n2:e2:n1": edge_attributes[gb.ORIGINAL_EDGE_ID][torch.tensor([0, 1])],
"n1:e1:n2": edge_attributes[gb.ORIGINAL_EDGE_ID][torch.tensor([4, 5])],
}
assert subgraph.original_column_node_ids is None
assert subgraph.original_row_node_ids is None
for etype in etypes.keys():
assert torch.equal(
subgraph.original_edge_ids[etype], expected_reverse_edge_ids[etype]
)
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("replace", [True, False])
@pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize("probs_name", ["weight", "mask"])
def test_sample_neighbors_probs(replace, labor, probs_name):
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
0 1 0 1 0
0 1 0 0 1
1 0 0 0 1
"""
# Initialize data.
total_num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
edge_attributes = {
"weight": torch.FloatTensor(
[2.5, 0, 8.4, 0, 0.4, 1.2, 2.5, 0, 8.4, 0.5, 0.4, 1.2]
),
"mask": torch.BoolTensor([1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1]),
}
# Construct FusedCSCSamplingGraph.
graph = gb.fused_csc_sampling_graph(
indptr, indices, edge_attributes=edge_attributes
)
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(
nodes,
fanouts=torch.tensor([2]),
replace=replace,
probs_name=probs_name,
output_cscformat=False,
)
# Verify in subgraph.
sampled_num = subgraph.sampled_csc[0].size(0)
if replace:
assert sampled_num == 6
else:
assert sampled_num == 4
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("replace", [True, False])
@pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize(
"probs_or_mask",
[
torch.zeros(12, dtype=torch.float32),
torch.zeros(12, dtype=torch.bool),
],
)
def test_sample_neighbors_zero_probs(replace, labor, probs_or_mask):
# Initialize data.
total_num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
ret_seeds, ret_timestamps = [], []
for ntype, nodes in seeds.items():
ntype_id = ntypes[ntype]
offset = node_type_offset[ntype_id]
ret_seeds.append(nodes + offset)
ret_timestamps.append(seed_timestamp[ntype])
return torch.cat(ret_seeds), torch.cat(ret_timestamps)
edge_attributes = {"probs_or_mask": probs_or_mask}
homo_seeds, homo_seed_timestamp = _to_homo()
# Construct FusedCSCSamplingGraph.
graph = gb.fused_csc_sampling_graph(
indptr, indices, edge_attributes=edge_attributes
)
def _get_available_neighbors():
available_neighbors = []
for i, seed in enumerate(homo_seeds):
neighbors = []
start = indptr[seed].item()
end = indptr[seed + 1].item()
for j in range(start, end):
neighbor = indices[j].item()
if (
use_node_timestamp
and (
node_timestamp[neighbor] > homo_seed_timestamp[i]
).item()
):
continue
if (
use_edge_timestamp
and (edge_timestamp[j] > homo_seed_timestamp[i]).item()
):
continue
neighbors.append(neighbor)
available_neighbors.append(neighbors)
return available_neighbors
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(
nodes,
fanouts=torch.tensor([5]),
replace=replace,
probs_name="probs_or_mask",
output_cscformat=False,
available_neighbors = _get_available_neighbors()
sampled_count = [0] * homo_seeds.numel()
sampled_neighbors = [[] for _ in range(homo_seeds.numel())]
for etype, csc in subgraph.sampled_csc.items():
stype, _, _ = etype_str_to_tuple(etype)
ntype_offset = ntypes_to_offset[stype]
dest_nodes = per_etype_destination_nodes[etype]
for i in range(dest_nodes.numel()):
l = csc.indptr[i]
r = csc.indptr[i + 1]
seed_offset = dest_nodes[i].item()
sampled_neighbors[seed_offset].extend(
(csc.indices[l:r] + ntype_offset).tolist()
)
sampled_count[seed_offset] += r - l
# Verify in subgraph.
sampled_num = subgraph.sampled_csc[0].size(0)
assert sampled_num == 0
for i, count in enumerate(sampled_count):
assert count == len(available_neighbors[i])
assert set(sampled_neighbors[i]).issubset(set(available_neighbors[i]))
def check_tensors_on_the_same_shared_memory(t1: torch.Tensor, t2: torch.Tensor):
......@@ -2158,199 +1511,6 @@ def test_from_dglgraph_heterogeneous():
}
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("replace", [False, True])
@pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize(
"fanouts, probs_name",
[
([2], "mask"),
([3], "mask"),
([4], "mask"),
([-1], "mask"),
([7], "mask"),
([3], "all"),
([-1], "all"),
([7], "all"),
([3], "zero"),
([-1], "zero"),
([3], "none"),
([-1], "none"),
],
)
def test_sample_neighbors_homo_pick_number(fanouts, replace, labor, probs_name):
"""Original graph in COO:
1 1 1 1 1 1
0 0 0 0 0 0
0 0 0 0 0 0
0 0 0 0 0 0
0 0 0 0 0 0
0 0 0 0 0 0
"""
# Initialize data.
total_num_nodes = 6
total_num_edges = 6
indptr = torch.LongTensor([0, 6, 6, 6, 6, 6, 6])
indices = torch.LongTensor([0, 1, 2, 3, 4, 5])
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
edge_attributes = {
"mask": torch.BoolTensor([1, 0, 0, 1, 0, 1]),
"all": torch.BoolTensor([1, 1, 1, 1, 1, 1]),
"zero": torch.BoolTensor([0, 0, 0, 0, 0, 0]),
}
# Construct FusedCSCSamplingGraph.
graph = gb.fused_csc_sampling_graph(
indptr, indices, edge_attributes=edge_attributes
)
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([0, 1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
# Make sure no exception will be thrown.
subgraph = sampler(
nodes,
fanouts=torch.LongTensor(fanouts),
replace=replace,
probs_name=probs_name if probs_name != "none" else None,
output_cscformat=False,
)
sampled_num = subgraph.sampled_csc[0].size(0)
# Verify in subgraph.
if probs_name == "mask":
if fanouts[0] == -1:
assert sampled_num == 3
else:
if replace:
assert sampled_num == fanouts[0]
else:
assert sampled_num == min(fanouts[0], 3)
elif probs_name == "zero":
assert sampled_num == 0
else:
if fanouts[0] == -1:
assert sampled_num == 6
else:
if replace:
assert sampled_num == fanouts[0]
else:
assert sampled_num == min(fanouts[0], 6)
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("replace", [False, True])
@pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize(
"fanouts, probs_name",
[
([-1, -1, -1], "mask"),
([1, 1, 1], "mask"),
([2, 2, 2], "mask"),
([3, 3, 3], "mask"),
([4, 4, 4], "mask"),
([-1, 1, 3], "none"),
([2, -1, 4], "none"),
],
)
def test_sample_neighbors_hetero_pick_number(
fanouts, replace, labor, probs_name
):
# Initialize data.
total_num_nodes = 10
total_num_edges = 9
ntypes = {"N0": 0, "N1": 1, "N2": 2, "N3": 3}
etypes = {
"N0:R0:N1": 0,
"N0:R1:N2": 1,
"N0:R2:N3": 2,
}
indptr = torch.LongTensor([0, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9])
indices = torch.LongTensor([1, 2, 3, 4, 5, 6, 7, 8, 9])
node_type_offset = torch.LongTensor([0, 1, 4, 7, 10])
type_per_edge = torch.LongTensor([0, 0, 0, 1, 1, 1, 2, 2, 2])
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
assert node_type_offset[-1] == total_num_nodes
assert all(type_per_edge < len(etypes))
edge_attributes = {
"mask": torch.BoolTensor([1, 1, 0, 1, 1, 1, 0, 0, 0]),
"all": torch.BoolTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]),
"zero": torch.BoolTensor([0, 0, 0, 0, 0, 0, 0, 0, 0]),
}
# Construct FusedCSCSamplingGraph.
graph = gb.fused_csc_sampling_graph(
indptr,
indices,
edge_attributes=edge_attributes,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([0, 1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
# Make sure no exception will be thrown.
subgraph = sampler(
nodes,
fanouts=torch.LongTensor(fanouts),
replace=replace,
probs_name=probs_name if probs_name != "none" else None,
output_cscformat=False,
)
if probs_name == "none":
for etype, pairs in subgraph.sampled_csc.items():
fanout = fanouts[etypes[etype]]
if fanout == -1:
assert pairs[0].size(0) == 3
else:
if replace:
assert pairs[0].size(0) == fanout
else:
assert pairs[0].size(0) == min(fanout, 3)
else:
fanout = fanouts[0] # Here fanout is the same for all etypes.
for etype, pairs in subgraph.sampled_csc.items():
if etypes[etype] == 0:
# Etype 0: 2 valid neighbors.
if fanout == -1:
assert pairs[0].size(0) == 2
else:
if replace:
assert pairs[0].size(0) == fanout
else:
assert pairs[0].size(0) == min(fanout, 2)
elif etypes[etype] == 1:
# Etype 1: 3 valid neighbors.
if fanout == -1:
assert pairs[0].size(0) == 3
else:
if replace:
assert pairs[0].size(0) == fanout
else:
assert pairs[0].size(0) == min(fanout, 3)
else:
# Etype 2: 0 valid neighbors.
assert pairs[0].size(0) == 0
@unittest.skipIf(
F._default_context_str == "cpu",
reason="`to` function needs GPU to test.",
......@@ -2408,7 +1568,7 @@ def test_csc_sampling_graph_to_device():
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
def test_sample_neighbors_homo_csc_format():
def test_sample_neighbors_homo():
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
......@@ -2428,9 +1588,7 @@ def test_sample_neighbors_homo_csc_format():
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
subgraph = graph.sample_neighbors(
nodes, fanouts=torch.LongTensor([2]), output_cscformat=True
)
subgraph = graph.sample_neighbors(nodes, fanouts=torch.LongTensor([2]))
# Verify in subgraph.
sampled_indptr_num = subgraph.sampled_csc.indptr.size(0)
......@@ -2447,7 +1605,7 @@ def test_sample_neighbors_homo_csc_format():
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_hetero_csc_format(labor):
def test_sample_neighbors_hetero(labor):
"""Original graph in COO:
"n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
......@@ -2482,7 +1640,7 @@ def test_sample_neighbors_hetero_csc_format(labor):
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts, output_cscformat=True)
subgraph = sampler(nodes, fanouts)
# Verify in subgraph.
expected_sampled_csc = {
......@@ -2507,7 +1665,7 @@ def test_sample_neighbors_hetero_csc_format(labor):
nodes = {"n1": torch.LongTensor([0])}
fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts, output_cscformat=True)
subgraph = sampler(nodes, fanouts)
# Verify in subgraph.
expected_sampled_csc = {
......@@ -2550,7 +1708,7 @@ def test_sample_neighbors_hetero_csc_format(labor):
],
)
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_fanouts_csc_format(
def test_sample_neighbors_fanouts(
fanouts, expected_sampled_num1, expected_sampled_num2, labor
):
"""Original graph in COO:
......@@ -2586,7 +1744,7 @@ def test_sample_neighbors_fanouts_csc_format(
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
fanouts = torch.LongTensor(fanouts)
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts, output_cscformat=True)
subgraph = sampler(nodes, fanouts)
# Verify in subgraph.
assert (
......@@ -2611,7 +1769,7 @@ def test_sample_neighbors_fanouts_csc_format(
"replace, expected_sampled_num1, expected_sampled_num2",
[(False, 2, 2), (True, 4, 4)],
)
def test_sample_neighbors_replace_csc_format(
def test_sample_neighbors_replace(
replace, expected_sampled_num1, expected_sampled_num2
):
"""Original graph in COO:
......@@ -2646,7 +1804,7 @@ def test_sample_neighbors_replace_csc_format(
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
subgraph = graph.sample_neighbors(
nodes, torch.LongTensor([4]), replace=replace, output_cscformat=True
nodes, torch.LongTensor([4]), replace=replace
)
# Verify in subgraph.
......@@ -2667,7 +1825,7 @@ def test_sample_neighbors_replace_csc_format(
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_return_eids_homo_csc_format(labor):
def test_sample_neighbors_return_eids_homo(labor):
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
......@@ -2693,9 +1851,7 @@ def test_sample_neighbors_return_eids_homo_csc_format(labor):
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(
nodes, fanouts=torch.LongTensor([-1]), output_cscformat=True
)
subgraph = sampler(nodes, fanouts=torch.LongTensor([-1]))
# Verify in subgraph.
expected_reverse_edge_ids = edge_attributes[gb.ORIGINAL_EDGE_ID][
......@@ -2711,7 +1867,7 @@ def test_sample_neighbors_return_eids_homo_csc_format(labor):
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_return_eids_hetero_csc_format(labor):
def test_sample_neighbors_return_eids_hetero(labor):
"""
Original graph in COO:
"n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
......@@ -2751,7 +1907,7 @@ def test_sample_neighbors_return_eids_hetero_csc_format(labor):
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts, output_cscformat=True)
subgraph = sampler(nodes, fanouts)
# Verify in subgraph.
expected_reverse_edge_ids = {
......@@ -2773,7 +1929,7 @@ def test_sample_neighbors_return_eids_hetero_csc_format(labor):
@pytest.mark.parametrize("replace", [True, False])
@pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize("probs_name", ["weight", "mask"])
def test_sample_neighbors_probs_csc_format(replace, labor, probs_name):
def test_sample_neighbors_probs(replace, labor, probs_name):
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
......@@ -2809,7 +1965,6 @@ def test_sample_neighbors_probs_csc_format(replace, labor, probs_name):
fanouts=torch.tensor([2]),
replace=replace,
probs_name=probs_name,
output_cscformat=True,
)
# Verify in subgraph.
......@@ -2834,7 +1989,7 @@ def test_sample_neighbors_probs_csc_format(replace, labor, probs_name):
torch.zeros(12, dtype=torch.bool),
],
)
def test_sample_neighbors_zero_probs_csc_format(replace, labor, probs_or_mask):
def test_sample_neighbors_zero_probs(replace, labor, probs_or_mask):
# Initialize data.
total_num_nodes = 5
total_num_edges = 12
......@@ -2858,7 +2013,6 @@ def test_sample_neighbors_zero_probs_csc_format(replace, labor, probs_or_mask):
fanouts=torch.tensor([5]),
replace=replace,
probs_name="probs_or_mask",
output_cscformat=True,
)
# Verify in subgraph.
......@@ -2890,9 +2044,7 @@ def test_sample_neighbors_zero_probs_csc_format(replace, labor, probs_or_mask):
([-1], "none"),
],
)
def test_sample_neighbors_homo_pick_number_csc_format(
fanouts, replace, labor, probs_name
):
def test_sample_neighbors_homo_pick_number(fanouts, replace, labor, probs_name):
"""Original graph in COO:
1 1 1 1 1 1
0 0 0 0 0 0
......@@ -2930,7 +2082,6 @@ def test_sample_neighbors_homo_pick_number_csc_format(
fanouts=torch.LongTensor(fanouts),
replace=replace,
probs_name=probs_name if probs_name != "none" else None,
output_cscformat=True,
)
sampled_num = subgraph.sampled_csc.indices.size(0)
assert subgraph.sampled_csc.indptr.size(0) == 3
......@@ -2973,7 +2124,7 @@ def test_sample_neighbors_homo_pick_number_csc_format(
([2, -1, 4], "none"),
],
)
def test_sample_neighbors_hetero_pick_number_csc_format(
def test_sample_neighbors_hetero_pick_number(
fanouts, replace, labor, probs_name
):
# Initialize data.
......@@ -3022,7 +2173,6 @@ def test_sample_neighbors_hetero_pick_number_csc_format(
fanouts=torch.LongTensor(fanouts),
replace=replace,
probs_name=probs_name if probs_name != "none" else None,
output_cscformat=True,
)
print(subgraph)
if probs_name == "none":
......@@ -1143,9 +1143,8 @@ def test_OnDiskDataset_preprocess_homogeneous():
subgraph = fused_csc_sampling_graph.sample_neighbors(
torch.arange(num_samples),
torch.tensor([fanout]),
output_cscformat=False,
)
assert len(subgraph.sampled_csc[0]) <= num_samples
assert len(subgraph.sampled_csc.indices) <= num_samples
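# The assertion above reads the sampled edge count from sampled_csc.indices because,
# with COO output removed, sampled_csc is a CSCFormatBase rather than a (row, column) tuple.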
with tempfile.TemporaryDirectory() as test_dir:
# All metadata fields are specified.