Unverified Commit e6f78c10 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Add `unique_and_compact_csc_format`. (#6703)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 72937574
...@@ -21,6 +21,7 @@ from .subgraph_sampler import * ...@@ -21,6 +21,7 @@ from .subgraph_sampler import *
from .internal import ( from .internal import (
compact_csc_format, compact_csc_format,
unique_and_compact, unique_and_compact,
unique_and_compact_csc_formats,
unique_and_compact_node_pairs, unique_and_compact_node_pairs,
) )
from .utils import add_reverse_edges, exclude_seed_edges from .utils import add_reverse_edges, exclude_seed_edges
......
...@@ -3,7 +3,11 @@ ...@@ -3,7 +3,11 @@
import torch import torch
from torch.utils.data import functional_datapipe from torch.utils.data import functional_datapipe
from ..internal import compact_csc_format, unique_and_compact_node_pairs from ..internal import (
compact_csc_format,
unique_and_compact_csc_formats,
unique_and_compact_node_pairs,
)
from ..subgraph_sampler import SubgraphSampler from ..subgraph_sampler import SubgraphSampler
from .sampled_subgraph_impl import FusedSampledSubgraphImpl, SampledSubgraphImpl from .sampled_subgraph_impl import FusedSampledSubgraphImpl, SampledSubgraphImpl
...@@ -123,11 +127,25 @@ class NeighborSampler(SubgraphSampler): ...@@ -123,11 +127,25 @@ class NeighborSampler(SubgraphSampler):
) )
if self.deduplicate: if self.deduplicate:
if self.output_cscformat: if self.output_cscformat:
raise RuntimeError("Not implemented yet.") (
original_row_node_ids,
compacted_csc_format,
) = unique_and_compact_csc_formats(
subgraph.node_pairs, seeds
)
subgraph = SampledSubgraphImpl(
node_pairs=compacted_csc_format,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
)
else:
( (
original_row_node_ids, original_row_node_ids,
compacted_node_pairs, compacted_node_pairs,
) = unique_and_compact_node_pairs(subgraph.node_pairs, seeds) ) = unique_and_compact_node_pairs(
subgraph.node_pairs, seeds
)
subgraph = FusedSampledSubgraphImpl( subgraph = FusedSampledSubgraphImpl(
node_pairs=compacted_node_pairs, node_pairs=compacted_node_pairs,
original_column_node_ids=seeds, original_column_node_ids=seeds,
......
...@@ -175,6 +175,116 @@ def unique_and_compact_node_pairs( ...@@ -175,6 +175,116 @@ def unique_and_compact_node_pairs(
return unique_nodes, compacted_node_pairs return unique_nodes, compacted_node_pairs
def unique_and_compact_csc_formats(
csc_formats: Union[
Tuple[torch.Tensor, torch.Tensor],
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
],
unique_dst_nodes: Union[
torch.Tensor,
Dict[str, torch.Tensor],
],
):
"""
Compact csc formats and return unique nodes (per type).
Parameters
----------
csc_formats : Union[CSCFormatBase, Dict(str, CSCFormatBase)]
CSC formats representing source-destination edges.
- If `csc_formats` is a CSCFormatBase: It means the graph is
homogeneous. Also, indptr and indice in it should be torch.tensor
representing source and destination pairs in csc format. And IDs inside
are homogeneous ids.
- If `csc_formats` is a Dict[str, CSCFormatBase]: The keys
should be edge type and the values should be csc format node pairs.
And IDs inside are heterogeneous ids.
unique_dst_nodes: torch.Tensor or Dict[str, torch.Tensor]
Unique nodes of all destination nodes in the node pairs.
- If `unique_dst_nodes` is a tensor: It means the graph is homogeneous.
- If `csc_formats` is a dictionary: The keys are node type and the
values are corresponding nodes. And IDs inside are heterogeneous ids.
Returns
-------
Tuple[csc_formats, unique_nodes]
The compacted csc formats, where node IDs are replaced with mapped node
IDs, and the unique nodes (per type).
"Compacted csc formats" indicates that the node IDs in the input node
pairs are replaced with mapped node IDs, where each type of node is
mapped to a contiguous space of IDs ranging from 0 to N.
Examples
--------
>>> import dgl.graphbolt as gb
>>> N1 = torch.LongTensor([1, 2, 2])
>>> N2 = torch.LongTensor([5, 5, 6])
>>> unique_dst = {
... "n1": torch.LongTensor([1, 2]),
... "n2": torch.LongTensor([5, 6])}
>>> csc_formats = {
... "n1:e1:n2": CSCFormatBase(indptr=torch.tensor([0, 2, 3]),indices=N1),
... "n2:e2:n1": CSCFormatBase(indptr=torch.tensor([0, 1, 3]),indices=N2)}
>>> unique_nodes, compacted_csc_formats = gb.unique_and_compact_csc_formats(
... csc_formats, unique_dst
... )
>>> print(unique_nodes)
{'n1': tensor([1, 2]), 'n2': tensor([5, 6])}
>>> print(compacted_csc_formats)
{"n1:e1:n2": CSCFormatBase(indptr=torch.tensor([0, 2, 3]),
indices=torch.tensor([0, 1, 1])),
"n2:e2:n1": CSCFormatBase(indptr=torch.tensor([0, 1, 3]),
indices=torch.Longtensor([0, 0, 1]))}
"""
is_homogeneous = not isinstance(csc_formats, dict)
if is_homogeneous:
csc_formats = {"_N:_E:_N": csc_formats}
if unique_dst_nodes is not None:
assert isinstance(
unique_dst_nodes, torch.Tensor
), "Edge type not supported in homogeneous graph."
unique_dst_nodes = {"_N": unique_dst_nodes}
# Collect all source and destination nodes for each node type.
indices = defaultdict(list)
for etype, csc_format in csc_formats.items():
src_type, _, _ = etype_str_to_tuple(etype)
indices[src_type].append(csc_format.indices)
indices = {ntype: torch.cat(nodes) for ntype, nodes in indices.items()}
ntypes = set(indices.keys())
unique_nodes = {}
compacted_indices = {}
dtype = list(indices.values())[0].dtype
default_tensor = torch.tensor([], dtype=dtype)
for ntype in ntypes:
indice = indices.get(ntype, default_tensor)
unique_dst = unique_dst_nodes.get(ntype, default_tensor)
(
unique_nodes[ntype],
compacted_indices[ntype],
_,
) = torch.ops.graphbolt.unique_and_compact(
indice, torch.tensor([], dtype=indice.dtype), unique_dst
)
compacted_csc_formats = {}
# Map back with the same order.
for etype, csc_format in csc_formats.items():
num_elem = csc_format.indices.size(0)
src_type, _, _ = etype_str_to_tuple(etype)
indice = compacted_indices[src_type][:num_elem]
indptr = csc_format.indptr
compacted_csc_formats[etype] = CSCFormatBase(
indptr=indptr, indices=indice
)
compacted_indices[src_type] = compacted_indices[src_type][num_elem:]
# Return singleton for a homogeneous graph.
if is_homogeneous:
compacted_csc_formats = list(compacted_csc_formats.values())[0]
unique_nodes = list(unique_nodes.values())[0]
return unique_nodes, compacted_csc_formats
def compact_csc_format( def compact_csc_format(
csc_formats: Union[CSCFormatBase, Dict[str, CSCFormatBase]], csc_formats: Union[CSCFormatBase, Dict[str, CSCFormatBase]],
dst_nodes: Union[torch.Tensor, Dict[str, torch.Tensor]], dst_nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
......
...@@ -184,6 +184,83 @@ def test_incomplete_unique_dst_nodes_(): ...@@ -184,6 +184,83 @@ def test_incomplete_unique_dst_nodes_():
gb.unique_and_compact_node_pairs(node_pairs, unique_dst_nodes) gb.unique_and_compact_node_pairs(node_pairs, unique_dst_nodes)
def test_unique_and_compact_csc_formats_hetero():
dst_nodes = {
"n2": torch.tensor([2, 4, 1, 3]),
"n3": torch.tensor([1, 3, 2, 7]),
}
csc_formats = {
"n1:e1:n2": gb.CSCFormatBase(
indptr=torch.tensor([0, 3, 4, 7, 10]),
indices=torch.tensor([1, 3, 4, 6, 2, 7, 9, 4, 2, 6]),
),
"n1:e2:n3": gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 4, 7, 10]),
indices=torch.tensor([5, 2, 6, 4, 7, 2, 8, 1, 3, 0]),
),
"n2:e3:n3": gb.CSCFormatBase(
indptr=torch.tensor([0, 2, 4, 6, 8]),
indices=torch.tensor([2, 5, 4, 1, 4, 3, 6, 0]),
),
}
expected_unique_nodes = {
"n1": torch.tensor([1, 3, 4, 6, 2, 7, 9, 5, 8, 0]),
"n2": torch.tensor([2, 4, 1, 3, 5, 6, 0]),
"n3": torch.tensor([1, 3, 2, 7]),
}
expected_csc_formats = {
"n1:e1:n2": gb.CSCFormatBase(
indptr=torch.tensor([0, 3, 4, 7, 10]),
indices=torch.tensor([0, 1, 2, 3, 4, 5, 6, 2, 4, 3]),
),
"n1:e2:n3": gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 4, 7, 10]),
indices=torch.tensor([7, 4, 3, 2, 5, 4, 8, 0, 1, 9]),
),
"n2:e3:n3": gb.CSCFormatBase(
indptr=torch.tensor([0, 2, 4, 6, 8]),
indices=torch.tensor([0, 4, 1, 2, 1, 3, 5, 6]),
),
}
unique_nodes, compacted_csc_formats = gb.unique_and_compact_csc_formats(
csc_formats, dst_nodes
)
for ntype, nodes in unique_nodes.items():
expected_nodes = expected_unique_nodes[ntype]
assert torch.equal(nodes, expected_nodes)
for etype, pair in compacted_csc_formats.items():
indices = pair.indices
indptr = pair.indptr
expected_indices = expected_csc_formats[etype].indices
expected_indptr = expected_csc_formats[etype].indptr
assert torch.equal(indices, expected_indices)
assert torch.equal(indptr, expected_indptr)
def test_unique_and_compact_csc_formats_homo():
seeds = torch.tensor([1, 3, 5, 2, 6])
indptr = torch.tensor([0, 2, 4, 6, 7, 10, 11])
indices = torch.tensor([2, 3, 1, 4, 5, 2, 5, 1, 4, 4, 6])
csc_formats = gb.CSCFormatBase(indptr=indptr, indices=indices)
expected_unique_nodes = torch.tensor([1, 3, 5, 2, 6, 4])
expected_indptr = indptr
expected_indices = torch.tensor([3, 1, 0, 5, 2, 3, 2, 0, 5, 5, 4])
unique_nodes, compacted_csc_formats = gb.unique_and_compact_csc_formats(
csc_formats, seeds
)
indptr = compacted_csc_formats.indptr
indices = compacted_csc_formats.indices
assert torch.equal(indptr, expected_indptr)
assert torch.equal(indices, expected_indices)
assert torch.equal(unique_nodes, expected_unique_nodes)
def test_compact_csc_format_hetero(): def test_compact_csc_format_hetero():
N1 = torch.randint(0, 50, (30,)) N1 = torch.randint(0, 50, (30,))
N2 = torch.randint(0, 50, (20,)) N2 = torch.randint(0, 50, (20,))
......
...@@ -370,3 +370,135 @@ def test_SubgraphSampler_without_dedpulication_Hetero(labor): ...@@ -370,3 +370,135 @@ def test_SubgraphSampler_without_dedpulication_Hetero(labor):
sampled_subgraph.node_pairs[etype].indptr, sampled_subgraph.node_pairs[etype].indptr,
csc_formats[step][etype].indptr, csc_formats[step][etype].indptr,
) )
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_unique_csc_format_Homo(labor):
torch.manual_seed(1205)
graph = dgl.graph(([5, 0, 6, 7, 2, 2, 4], [0, 1, 2, 2, 3, 4, 4]))
graph = gb.from_dglgraph(graph, True)
seed_nodes = torch.LongTensor([0, 3, 4])
itemset = gb.ItemSet(seed_nodes, names="seed_nodes")
item_sampler = gb.ItemSampler(itemset, batch_size=len(seed_nodes))
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
datapipe = Sampler(
item_sampler,
graph,
fanouts,
replace=False,
deduplicate=True,
output_cscformat=True,
)
original_row_node_ids = [
torch.tensor([0, 3, 4, 5, 2, 6, 7]),
torch.tensor([0, 3, 4, 5, 2]),
]
compacted_indices = [
torch.tensor([3, 4, 4, 2, 5, 6]),
torch.tensor([3, 4, 4, 2]),
]
indptr = [
torch.tensor([0, 1, 2, 4, 4, 6]),
torch.tensor([0, 1, 2, 4]),
]
seeds = [torch.tensor([0, 3, 4, 5, 2]), torch.tensor([0, 3, 4])]
for data in datapipe:
for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
assert torch.equal(
sampled_subgraph.original_row_node_ids,
original_row_node_ids[step],
)
assert torch.equal(
sampled_subgraph.node_pairs.indices, compacted_indices[step]
)
assert torch.equal(sampled_subgraph.node_pairs.indptr, indptr[step])
assert torch.equal(
sampled_subgraph.original_column_node_ids, seeds[step]
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_unique_csc_format_Hetero(labor):
graph = get_hetero_graph()
itemset = gb.ItemSetDict(
{"n2": gb.ItemSet(torch.arange(2), names="seed_nodes")}
)
item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
datapipe = Sampler(
item_sampler,
graph,
fanouts,
deduplicate=True,
output_cscformat=True,
)
csc_formats = [
{
"n1:e1:n2": gb.CSCFormatBase(
indptr=torch.tensor([0, 2, 4]),
indices=torch.tensor([0, 1, 1, 0]),
),
"n2:e2:n1": gb.CSCFormatBase(
indptr=torch.tensor([0, 2, 4]),
indices=torch.tensor([0, 2, 0, 1]),
),
},
{
"n1:e1:n2": gb.CSCFormatBase(
indptr=torch.tensor([0, 2, 4]),
indices=torch.tensor([0, 1, 1, 0]),
),
"n2:e2:n1": gb.CSCFormatBase(
indptr=torch.tensor([0]),
indices=torch.tensor([], dtype=torch.int64),
),
},
]
original_column_node_ids = [
{
"n1": torch.tensor([0, 1]),
"n2": torch.tensor([0, 1]),
},
{
"n1": torch.tensor([], dtype=torch.int64),
"n2": torch.tensor([0, 1]),
},
]
original_row_node_ids = [
{
"n1": torch.tensor([0, 1]),
"n2": torch.tensor([0, 1, 2]),
},
{
"n1": torch.tensor([0, 1]),
"n2": torch.tensor([0, 1]),
},
]
for data in datapipe:
for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
for ntype in ["n1", "n2"]:
assert torch.equal(
sampled_subgraph.original_row_node_ids[ntype],
original_row_node_ids[step][ntype],
)
assert torch.equal(
sampled_subgraph.original_column_node_ids[ntype],
original_column_node_ids[step][ntype],
)
for etype in ["n1:e1:n2", "n2:e2:n1"]:
assert torch.equal(
sampled_subgraph.node_pairs[etype].indices,
csc_formats[step][etype].indices,
)
assert torch.equal(
sampled_subgraph.node_pairs[etype].indptr,
csc_formats[step][etype].indptr,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment