Unverified Commit 4df2e399 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Add not_deduplication compact. (#6579)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 308f8ca3
......@@ -20,6 +20,7 @@ from .sampled_subgraph import *
from .subgraph_sampler import *
from .utils import (
add_reverse_edges,
compact_csc_format,
exclude_seed_edges,
unique_and_compact,
unique_and_compact_node_pairs,
......
......@@ -4,8 +4,8 @@ import torch
from torch.utils.data import functional_datapipe
from ..subgraph_sampler import SubgraphSampler
from ..utils import unique_and_compact_node_pairs
from .sampled_subgraph_impl import FusedSampledSubgraphImpl
from ..utils import compact_csc_format, unique_and_compact_node_pairs
from .sampled_subgraph_impl import FusedSampledSubgraphImpl, SampledSubgraphImpl
__all__ = ["NeighborSampler", "LayerNeighborSampler"]
......@@ -122,14 +122,27 @@ class NeighborSampler(SubgraphSampler):
original_row_node_ids,
compacted_node_pairs,
) = unique_and_compact_node_pairs(subgraph.node_pairs, seeds)
subgraph = FusedSampledSubgraphImpl(
node_pairs=compacted_node_pairs,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
)
else:
raise RuntimeError("Not implemented yet.")
subgraph = FusedSampledSubgraphImpl(
node_pairs=compacted_node_pairs,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
)
(
original_row_node_ids,
compacted_csc_format,
) = compact_csc_format(subgraph.node_pairs, seeds)
# [TODO] `node_pairs` is defined in the parent class SampledSubgraph and
# is still inherited by other classes, so the name cannot be changed
# for now even though it holds a CSC format here. This part will be
# cleaned up later.
subgraph = SampledSubgraphImpl(
node_pairs=compacted_csc_format,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
)
subgraphs.insert(0, subgraph)
seeds = original_row_node_ids
return seeds, subgraphs
......
"""Utility functions for sampling."""
import copy
from collections import defaultdict
from typing import Dict, List, Tuple, Union
import torch
from ..base import etype_str_to_tuple
from ..base import CSCFormatBase, etype_str_to_tuple
from ..minibatch import MiniBatch
......@@ -267,3 +268,106 @@ def unique_and_compact_node_pairs(
unique_nodes = list(unique_nodes.values())[0]
return unique_nodes, compacted_node_pairs
def compact_csc_format(
    csc_formats: Union[CSCFormatBase, Dict[str, CSCFormatBase]],
    dst_nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
):
    """
    Compact CSC formats without deduplication and return the original row
    node IDs (per type).

    Every entry of ``indices`` is assigned its own compacted ID: the
    destination nodes occupy the first positions of the returned row IDs and
    the entries of ``indices`` are appended after them in order, so repeated
    source nodes are NOT merged.

    Parameters
    ----------
    csc_formats: Union[CSCFormatBase, Dict[str, CSCFormatBase]]
        CSC formats representing source-destination edges.
        - If `csc_formats` is a CSCFormatBase: It means the graph is
        homogeneous. Its `indptr` and `indices` should be torch tensors
        describing the edges in CSC layout, and the IDs inside are
        homogeneous ids.
        - If `csc_formats` is a Dict[str, CSCFormatBase]: The keys should be
        edge types and the values should be the CSC formats of those edge
        types. The IDs inside are heterogeneous ids.
    dst_nodes: Union[torch.Tensor, Dict[str, torch.Tensor]]
        All destination nodes of the edges; there must be exactly
        ``len(indptr) - 1`` of them (per destination type).
        - If `dst_nodes` is a tensor: It means the graph is homogeneous.
        - If `dst_nodes` is a dictionary: The keys are node types and the
        values are the corresponding nodes. The IDs inside are heterogeneous
        ids.

    Returns
    -------
    Tuple[original_row_node_ids, compacted_csc_formats]
        `original_row_node_ids` holds the original node IDs (per type): the
        destination nodes followed by the entries of `indices`.
        `compacted_csc_formats` holds the same `indptr` as the input, while
        `indices` is replaced by the positions of the corresponding entries
        inside `original_row_node_ids`, i.e. each type of node is mapped to
        a contiguous space of IDs starting from 0.

    Raises
    ------
    AssertionError
        If the last element of `indptr` differs from the length of
        `indices`, if the number of destination nodes does not match
        `indptr`, or if `dst_nodes` is not a tensor for a homogeneous graph.

    Examples
    --------
    >>> import dgl.graphbolt as gb
    >>> N1 = torch.LongTensor([1, 2, 2])
    >>> N2 = torch.LongTensor([5, 6, 5])
    >>> csc_formats = {"n2:e2:n1": gb.CSCFormatBase(
    ...     indptr=torch.tensor([0, 1]), indices=torch.tensor([5]))}
    >>> dst_nodes = {"n1": N1[:1]}
    >>> original_row_node_ids, compacted_csc_formats = gb.compact_csc_format(
    ...     csc_formats, dst_nodes
    ... )
    >>> print(original_row_node_ids)
    {'n1': tensor([1]), 'n2': tensor([5])}
    >>> print(compacted_csc_formats)
    {"n2:e2:n1": CSCFormatBase(indptr=tensor([0, 1]),
    ...     indices=tensor([0]))}
    """
    if not isinstance(csc_formats, dict):
        # Homogeneous graph: a single CSC format and a single tensor of
        # destination nodes.
        indices = csc_formats.indices
        # The None guard only skips the type check; `dst_nodes` itself is
        # still required by the statements below.
        if dst_nodes is not None:
            assert isinstance(
                dst_nodes, torch.Tensor
            ), "Edge type not supported in homogeneous graph."
        assert csc_formats.indptr[-1] == len(
            indices
        ), "The last element of indptr should be the same as the length of indices."
        assert len(dst_nodes) + 1 == len(
            csc_formats.indptr
        ), "The seed nodes should correspond to indptr."
        # Source entries are placed right after the destination nodes.
        offset = dst_nodes.size(0)
        original_row_ids = torch.cat((dst_nodes, indices))
        compacted_csc_formats = CSCFormatBase(
            indptr=csc_formats.indptr,
            # Keep dtype and device of the input indices so that the
            # homogeneous path agrees with the heterogeneous one.
            indices=torch.arange(
                offset,
                offset + indices.size(0),
                dtype=indices.dtype,
                device=indices.device,
            ),
        )
        return original_row_ids, compacted_csc_formats

    # Heterogeneous graph: one CSC format per edge type.
    compacted_csc_formats = {}
    # Work on a copy so that the caller's dictionary is left untouched.
    original_row_ids = copy.deepcopy(dst_nodes)
    for etype, csc_format in csc_formats.items():
        indices = csc_format.indices
        assert csc_format.indptr[-1] == len(
            indices
        ), "The last element of indptr should be the same as the length of indices."
        src_type, _, dst_type = etype_str_to_tuple(etype)
        assert len(dst_nodes[dst_type]) + 1 == len(
            csc_format.indptr
        ), "The seed nodes should correspond to indptr."
        # Row IDs already collected for this source type (destination nodes
        # of that type and/or indices of previously visited edge types).
        collected = original_row_ids.get(src_type)
        if collected is None:
            collected = torch.tensor(
                [], dtype=indices.dtype, device=indices.device
            )
        # New entries are appended after everything collected so far.
        offset = collected.size(0)
        original_row_ids[src_type] = torch.cat((collected, indices))
        compacted_csc_formats[etype] = CSCFormatBase(
            indptr=csc_format.indptr,
            indices=torch.arange(
                offset,
                offset + indices.size(0),
                dtype=indices.dtype,
                device=indices.device,
            ),
        )
    return original_row_ids, compacted_csc_formats
......@@ -182,3 +182,67 @@ def test_incomplete_unique_dst_nodes_():
unique_dst_nodes = torch.arange(150, 200)
with pytest.raises(IndexError):
gb.unique_and_compact_node_pairs(node_pairs, unique_dst_nodes)
def test_compact_csc_format_hetero():
    """Check `gb.compact_csc_format` on a heterogeneous input.

    The returned row IDs must equal the destination nodes followed by the
    indices of every edge type with that source type, and mapping the
    compacted indices back through the row IDs must restore the input.
    """
    # Random original IDs for the three node types.
    N1 = torch.randint(0, 50, (30,))
    N2 = torch.randint(0, 50, (20,))
    N3 = torch.randint(0, 50, (10,))

    dst_nodes = {"n2": N2[:10], "n3": N3}
    csc_formats = {
        "n1:e1:n2": gb.CSCFormatBase(
            indptr=torch.arange(0, 22, 2),
            indices=N1[:20],
        ),
        "n1:e2:n3": gb.CSCFormatBase(
            indptr=torch.arange(0, 11),
            indices=N1[20:30],
        ),
        "n2:e3:n3": gb.CSCFormatBase(
            indptr=torch.arange(0, 11),
            indices=N2[10:],
        ),
    }
    expected_original_row_ids = {"n1": N1, "n2": N2, "n3": N3}

    original_row_ids, compacted_csc_formats = gb.compact_csc_format(
        csc_formats, dst_nodes
    )

    # Row IDs per node type.
    for ntype, nodes in original_row_ids.items():
        assert torch.equal(nodes, expected_original_row_ids[ntype])

    # Compacted CSC formats per edge type.
    for etype, compacted in compacted_csc_formats.items():
        src_type, _, _ = gb.etype_str_to_tuple(etype)
        # Translate compacted IDs back into original IDs.
        restored_indices = original_row_ids[src_type][compacted.indices]
        reference = csc_formats[etype]
        assert torch.equal(compacted.indptr, reference.indptr)
        assert torch.equal(restored_indices, reference.indices)
def test_compact_csc_format_homo():
    """Check `gb.compact_csc_format` on a homogeneous input.

    The first 10 random IDs act as destination nodes and the remaining 190
    as CSC indices (19 per destination node).
    """
    N = torch.randint(0, 50, (200,))
    dst_nodes = N[:10]
    csc_formats = gb.CSCFormatBase(
        indptr=torch.arange(0, 191, 19), indices=N[10:]
    )

    original_row_ids, compacted_csc_formats = gb.compact_csc_format(
        csc_formats, dst_nodes
    )

    # Translate compacted IDs back into original IDs.
    restored_indices = N[compacted_csc_formats.indices]
    assert torch.equal(compacted_csc_formats.indptr, csc_formats.indptr)
    assert torch.equal(restored_indices, csc_formats.indices)
    # Destination nodes followed by all indices reproduce N itself.
    assert torch.equal(original_row_ids, N)
import dgl
import dgl.graphbolt as gb
import gb_test_utils
import pytest
......@@ -255,3 +256,117 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
torch.ge(value, torch.zeros(len(value))),
torch.ones(len(value)),
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_without_dedpulication_Homo(labor):
    """Sample two layers on a small homogeneous graph with
    ``deduplicate=False`` and compare each sampled subgraph against the
    hard-coded expectations below.
    """
    src_ids = [5, 0, 1, 5, 6, 7, 2, 2, 4]
    dst_ids = [0, 1, 2, 2, 2, 2, 3, 4, 4]
    graph = gb.from_dglgraph(dgl.graph((src_ids, dst_ids)), True)

    seed_nodes = torch.LongTensor([0, 3, 4])
    item_sampler = gb.ItemSampler(
        gb.ItemSet(seed_nodes, names="seed_nodes"),
        batch_size=len(seed_nodes),
    )
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    if labor:
        Sampler = gb.LayerNeighborSampler
    else:
        Sampler = gb.NeighborSampler
    datapipe = Sampler(item_sampler, graph, fanouts, deduplicate=False)

    # Expected values, indexed by the position of the sampled subgraph.
    expected_num_rows = [17, 7]
    expected_indices = [
        torch.arange(0, 10) + 7,
        torch.arange(0, 4) + 3,
    ]
    expected_indptr = [
        torch.tensor([0, 1, 2, 4, 4, 6, 8, 10]),
        torch.tensor([0, 1, 2, 4]),
    ]
    expected_columns = [
        torch.tensor([0, 3, 4, 5, 2, 2, 4]),
        torch.tensor([0, 3, 4]),
    ]

    for data in datapipe:
        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
            row_ids = sampled_subgraph.original_row_node_ids
            assert len(row_ids) == expected_num_rows[step]
            compacted = sampled_subgraph.node_pairs
            assert torch.equal(compacted.indices, expected_indices[step])
            assert torch.equal(
                sampled_subgraph.node_pairs.indptr, expected_indptr[step]
            )
            assert torch.equal(
                sampled_subgraph.original_column_node_ids,
                expected_columns[step],
            )
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_without_dedpulication_Hetero(labor):
    """Sample two layers on the heterogeneous test graph with
    ``deduplicate=False`` and compare each sampled subgraph against the
    hard-coded expectations below.
    """

    def ids(values):
        # int64 tensor; also covers the empty case explicitly.
        return torch.tensor(values, dtype=torch.int64)

    graph = get_hetero_graph()
    itemset = gb.ItemSetDict(
        {"n2": gb.ItemSet(torch.arange(2), names="seed_nodes")}
    )
    item_sampler = gb.ItemSampler(itemset, batch_size=2)
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    if labor:
        Sampler = gb.LayerNeighborSampler
    else:
        Sampler = gb.NeighborSampler
    datapipe = Sampler(item_sampler, graph, fanouts, deduplicate=False)

    # Expected values, indexed by the position of the sampled subgraph.
    expected_csc_formats = [
        {
            "n1:e1:n2": gb.CSCFormatBase(
                indptr=ids([0, 2, 4]),
                indices=ids([4, 5, 6, 7]),
            ),
            "n2:e2:n1": gb.CSCFormatBase(
                indptr=ids([0, 2, 4, 6, 8]),
                indices=ids([2, 3, 4, 5, 6, 7, 8, 9]),
            ),
        },
        {
            "n1:e1:n2": gb.CSCFormatBase(
                indptr=ids([0, 2, 4]),
                indices=ids([0, 1, 2, 3]),
            ),
            "n2:e2:n1": gb.CSCFormatBase(
                indptr=ids([0]),
                indices=ids([]),
            ),
        },
    ]
    expected_column_ids = [
        {"n1": ids([0, 1, 1, 0]), "n2": ids([0, 1])},
        {"n1": ids([]), "n2": ids([0, 1])},
    ]
    expected_row_ids = [
        {
            "n1": ids([0, 1, 1, 0, 0, 1, 1, 0]),
            "n2": ids([0, 1, 0, 2, 0, 1, 0, 1, 0, 2]),
        },
        {"n1": ids([0, 1, 1, 0]), "n2": ids([0, 1])},
    ]

    for data in datapipe:
        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
            # Original node IDs per node type.
            for ntype in ("n1", "n2"):
                assert torch.equal(
                    sampled_subgraph.original_row_node_ids[ntype],
                    expected_row_ids[step][ntype],
                )
                assert torch.equal(
                    sampled_subgraph.original_column_node_ids[ntype],
                    expected_column_ids[step][ntype],
                )
            # Compacted CSC format per edge type.
            for etype in ("n1:e1:n2", "n2:e2:n1"):
                assert torch.equal(
                    sampled_subgraph.node_pairs[etype].indices,
                    expected_csc_formats[step][etype].indices,
                )
                assert torch.equal(
                    sampled_subgraph.node_pairs[etype].indptr,
                    expected_csc_formats[step][etype].indptr,
                )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment