Unverified Commit 898af658 authored by czkkkkkk, committed by GitHub

[Graphbolt] Add temporal sampling unittests. (#6795)

parent 1f9ae668
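
For context, the new tests exercise graph.temporal_sample_neighbors on a FusedCSCSamplingGraph. Below is a minimal sketch of the homogeneous call pattern they cover, with illustrative values and argument names taken from the test code itself (a sketch, not an authoritative API reference):

import torch
import dgl.graphbolt as gb

# Tiny CSC graph; node j's in-neighbors are indices[indptr[j]:indptr[j + 1]].
indptr = torch.tensor([0, 2, 3])
indices = torch.tensor([1, 0, 0])
graph = gb.fused_csc_sampling_graph(indptr, indices)
# Optional per-edge timestamps used to filter out edges newer than the seed.
graph.edge_attributes = {"timestamp": torch.tensor([5, 7, 9])}

subgraph, neighbors_timestamp = graph.temporal_sample_neighbors(
    torch.tensor([0, 1]),   # seed nodes (dtype must match indices)
    torch.tensor([8, 8]),   # one timestamp per seed
    torch.LongTensor([2]),  # fanout
    replace=False,
    edge_timestamp_attr_name="timestamp",
)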
@@ -12,6 +12,8 @@ import dgl.graphbolt as gb
import pytest
import torch
import torch.multiprocessing as mp
from dgl.graphbolt.base import etype_str_to_tuple
from scipy import sparse as spsp
from .. import gb_test_utils as gbt
@@ -1001,6 +1003,124 @@ def test_sample_neighbors_homo(labor, indptr_dtype, indices_dtype):
    assert subgraph.original_edge_ids is None


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("replace", [False, True])
@pytest.mark.parametrize("use_node_timestamp", [False, True])
@pytest.mark.parametrize("use_edge_timestamp", [False, True])
def test_temporal_sample_neighbors_homo(
    indptr_dtype, indices_dtype, replace, use_node_timestamp, use_edge_timestamp
):
    """Original graph in COO:
    1 0 1 0 1
    1 0 1 1 0
    0 1 0 1 0
    0 1 0 0 1
    1 0 0 0 1
    """
    # Initialize data.
    total_num_nodes = 5
    total_num_edges = 12
    indptr = torch.tensor([0, 3, 5, 7, 9, 12], dtype=indptr_dtype)
    indices = torch.tensor(
        [0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4], dtype=indices_dtype
    )
    assert indptr[-1] == total_num_edges
    assert indptr[-1] == len(indices)
    assert len(indptr) == total_num_nodes + 1
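    # The matrix above is read in CSC form: column j lists the in-neighbors of
    # node j, e.g. column 0 has nonzero rows 0, 1 and 4, which matches
    # indices[indptr[0]:indptr[1]] == [0, 1, 4].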

    # Construct FusedCSCSamplingGraph.
    graph = gb.fused_csc_sampling_graph(indptr, indices)

    # Generate subgraph via sample neighbors.
    fanouts = torch.LongTensor([2])
    sampler = graph.temporal_sample_neighbors
    seed_list = [1, 3, 4]
    seed_timestamp = torch.randint(0, 100, (len(seed_list),), dtype=torch.int64)
    if use_node_timestamp:
        node_timestamp = torch.randint(
            0, 100, (total_num_nodes,), dtype=torch.int64
        )
        graph.node_attributes = {"timestamp": node_timestamp}
    if use_edge_timestamp:
        edge_timestamp = torch.randint(
            0, 100, (total_num_edges,), dtype=torch.int64
        )
        graph.edge_attributes = {"timestamp": edge_timestamp}
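
    # If provided, the "timestamp" node/edge attributes are expected to
    # constrain sampling to neighbors whose timestamps do not exceed the
    # seed's timestamp; _get_available_neighbors() below recomputes this
    # as a reference.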

    # Sampling with seed nodes whose dtype mismatches the graph's indices
    # should raise.
    nodes = torch.tensor(
        seed_list,
        dtype=(torch.int64 if indices_dtype == torch.int32 else torch.int32),
    )
    with pytest.raises(
        AssertionError,
        match=re.escape(
            "Data type of nodes must be consistent with indices.dtype"
        ),
    ):
        _ = sampler(
            nodes,
            seed_timestamp,
            fanouts,
            replace=replace,
            node_timestamp_attr_name="timestamp"
            if use_node_timestamp
            else None,
            edge_timestamp_attr_name="timestamp"
            if use_edge_timestamp
            else None,
        )
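
    # Reference oracle: for each seed, enumerate the neighbors whose node and
    # edge timestamps (when used) are not greater than the seed's timestamp.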
    def _get_available_neighbors():
        available_neighbors = []
        for i, seed in enumerate(seed_list):
            neighbors = []
            start = indptr[seed].item()
            end = indptr[seed + 1].item()
            for j in range(start, end):
                neighbor = indices[j].item()
                if (
                    use_node_timestamp
                    and (node_timestamp[neighbor] > seed_timestamp[i]).item()
                ):
                    continue
                if (
                    use_edge_timestamp
                    and (edge_timestamp[j] > seed_timestamp[i]).item()
                ):
                    continue
                neighbors.append(neighbor)
            available_neighbors.append(neighbors)
        return available_neighbors

    nodes = torch.tensor(seed_list, dtype=indices_dtype)
    subgraph, neighbors_timestamp = sampler(
        nodes,
        seed_timestamp,
        fanouts,
        replace=replace,
        node_timestamp_attr_name="timestamp" if use_node_timestamp else None,
        edge_timestamp_attr_name="timestamp" if use_edge_timestamp else None,
    )
    sampled_count = torch.diff(subgraph.node_pairs.indptr).tolist()
    available_neighbors = _get_available_neighbors()
    for i, count in enumerate(sampled_count):
        if not replace:
            expect_count = min(fanouts[0], len(available_neighbors[i]))
        else:
            expect_count = fanouts[0] if len(available_neighbors[i]) > 0 else 0
        assert count == expect_count
    sampled_neighbors = torch.split(subgraph.node_pairs.indices, sampled_count)
    for i, neighbors in enumerate(sampled_neighbors):
        assert set(neighbors.tolist()).issubset(set(available_neighbors[i]))


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Graph is CPU only at present.",
@@ -1137,6 +1257,143 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
    assert subgraph.original_edge_ids is None


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("replace", [False, True])
@pytest.mark.parametrize("use_node_timestamp", [False, True])
@pytest.mark.parametrize("use_edge_timestamp", [False, True])
def test_temporal_sample_neighbors_hetero(
    indptr_dtype, indices_dtype, replace, use_node_timestamp, use_edge_timestamp
):
    """Original graph in COO:
    "n1:e1:n2": [0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
    "n2:e2:n1": [0, 0, 1, 2], [0, 1, 1, 0]
    0 0 1 0 1
    0 0 1 1 1
    1 1 0 0 0
    0 1 0 0 0
    1 0 0 0 0
    """
    # Initialize data.
    ntypes = {"n1": 0, "n2": 1}
    etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
    ntypes_to_offset = {"n1": 0, "n2": 2}
    total_num_nodes = 5
    total_num_edges = 9
    indptr = torch.tensor([0, 2, 4, 6, 7, 9], dtype=indptr_dtype)
    indices = torch.tensor([2, 4, 2, 3, 0, 1, 1, 0, 1], dtype=indices_dtype)
    type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
    node_type_offset = torch.LongTensor([0, 2, 5])
    assert indptr[-1] == total_num_edges
    assert indptr[-1] == len(indices)
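    # The per-etype COO lists above are stored in a single homogeneous CSC
    # graph: n1 nodes occupy ids [0, 2) and n2 nodes ids [2, 5) per
    # node_type_offset, and type_per_edge records each edge's etype id.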

    # Construct FusedCSCSamplingGraph.
    graph = gb.fused_csc_sampling_graph(
        indptr,
        indices,
        node_type_offset=node_type_offset,
        type_per_edge=type_per_edge,
        node_type_to_id=ntypes,
        edge_type_to_id=etypes,
    )

    # Generate subgraph via sample neighbors.
    fanouts = torch.LongTensor([-1, -1])
    sampler = graph.temporal_sample_neighbors
    seeds = {
        "n1": torch.tensor([0], dtype=indices_dtype),
        "n2": torch.tensor([0], dtype=indices_dtype),
    }
    per_etype_destination_nodes = {
        "n1:e1:n2": torch.tensor([1], dtype=indices_dtype),
        "n2:e2:n1": torch.tensor([0], dtype=indices_dtype),
    }
    seed_timestamp = {
        "n1": torch.randint(0, 100, (1,), dtype=torch.int64),
        "n2": torch.randint(0, 100, (1,), dtype=torch.int64),
    }
    if use_node_timestamp:
        node_timestamp = torch.randint(
            0, 100, (total_num_nodes,), dtype=torch.int64
        )
        graph.node_attributes = {"timestamp": node_timestamp}
    if use_edge_timestamp:
        edge_timestamp = torch.randint(
            0, 100, (total_num_edges,), dtype=torch.int64
        )
        graph.edge_attributes = {"timestamp": edge_timestamp}
    subgraph, neighbors_timestamp = sampler(
        seeds,
        seed_timestamp,
        fanouts,
        replace=replace,
        node_timestamp_attr_name="timestamp" if use_node_timestamp else None,
        edge_timestamp_attr_name="timestamp" if use_edge_timestamp else None,
    )

    def _to_homo():
        ret_seeds, ret_timestamps = [], []
        for ntype, nodes in seeds.items():
            ntype_id = ntypes[ntype]
            offset = node_type_offset[ntype_id]
            ret_seeds.append(nodes + offset)
            ret_timestamps.append(seed_timestamp[ntype])
        return torch.cat(ret_seeds), torch.cat(ret_timestamps)

    homo_seeds, homo_seed_timestamp = _to_homo()

    def _get_available_neighbors():
        available_neighbors = []
        for i, seed in enumerate(homo_seeds):
            neighbors = []
            start = indptr[seed].item()
            end = indptr[seed + 1].item()
            for j in range(start, end):
                neighbor = indices[j].item()
                if (
                    use_node_timestamp
                    and (
                        node_timestamp[neighbor] > homo_seed_timestamp[i]
                    ).item()
                ):
                    continue
                if (
                    use_edge_timestamp
                    and (edge_timestamp[j] > homo_seed_timestamp[i]).item()
                ):
                    continue
                neighbors.append(neighbor)
            available_neighbors.append(neighbors)
        return available_neighbors

    available_neighbors = _get_available_neighbors()
    sampled_count = [0] * homo_seeds.numel()
    sampled_neighbors = [[] for _ in range(homo_seeds.numel())]
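    # Re-assemble the per-edge-type CSC results into homogeneous neighbor
    # lists for comparison with the reference above: source ids are shifted by
    # their node-type offset, and per_etype_destination_nodes gives each
    # destination's position in homo_seeds.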
    for etype, csc in subgraph.node_pairs.items():
        stype, _, _ = etype_str_to_tuple(etype)
        ntype_offset = ntypes_to_offset[stype]
        dest_nodes = per_etype_destination_nodes[etype]
        for i in range(dest_nodes.numel()):
            l = csc.indptr[i]
            r = csc.indptr[i + 1]
            seed_offset = dest_nodes[i].item()
            sampled_neighbors[seed_offset].extend(
                (csc.indices[l:r] + ntype_offset).tolist()
            )
            sampled_count[seed_offset] += r - l
    for i, count in enumerate(sampled_count):
        assert count == len(available_neighbors[i])
        assert set(sampled_neighbors[i]).issubset(set(available_neighbors[i]))


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Graph is CPU only at present.",
......