Unverified commit a1724c18, authored by yxy235 and committed by GitHub.
Browse files

[GraphBolt] Modify `InSubgraphSampler` to support csc format. (#6724)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent b0d26199
...@@ -301,8 +301,11 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -301,8 +301,11 @@ class FusedCSCSamplingGraph(SamplingGraph):
self._c_csc_graph.set_edge_attributes(edge_attributes) self._c_csc_graph.set_edge_attributes(edge_attributes)
def in_subgraph( def in_subgraph(
self, nodes: Union[torch.Tensor, Dict[str, torch.Tensor]] self,
) -> FusedSampledSubgraphImpl: nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
# TODO: clean up once the migration is done.
output_cscformat=False,
) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
"""Return the subgraph induced on the inbound edges of the given nodes. """Return the subgraph induced on the inbound edges of the given nodes.
An in subgraph is equivalent to creating a new graph using the incoming An in subgraph is equivalent to creating a new graph using the incoming
...@@ -361,7 +364,10 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -361,7 +364,10 @@ class FusedCSCSamplingGraph(SamplingGraph):
), "Nodes cannot have duplicate values." ), "Nodes cannot have duplicate values."
_in_subgraph = self._c_csc_graph.in_subgraph(nodes) _in_subgraph = self._c_csc_graph.in_subgraph(nodes)
if not output_cscformat:
return self._convert_to_fused_sampled_subgraph(_in_subgraph) return self._convert_to_fused_sampled_subgraph(_in_subgraph)
else:
return self._convert_to_sampled_subgraph(_in_subgraph)
def _convert_to_fused_sampled_subgraph( def _convert_to_fused_sampled_subgraph(
self, self,
......
...@@ -2,10 +2,13 @@ ...@@ -2,10 +2,13 @@
from torch.utils.data import functional_datapipe from torch.utils.data import functional_datapipe
from ..internal import unique_and_compact_node_pairs from ..internal import (
unique_and_compact_csc_formats,
unique_and_compact_node_pairs,
)
from ..subgraph_sampler import SubgraphSampler from ..subgraph_sampler import SubgraphSampler
from .sampled_subgraph_impl import FusedSampledSubgraphImpl from .sampled_subgraph_impl import FusedSampledSubgraphImpl, SampledSubgraphImpl
__all__ = ["InSubgraphSampler"] __all__ = ["InSubgraphSampler"]
...@@ -56,13 +59,17 @@ class InSubgraphSampler(SubgraphSampler): ...@@ -56,13 +59,17 @@ class InSubgraphSampler(SubgraphSampler):
self, self,
datapipe, datapipe,
graph, graph,
# TODO: clean up once the migration is done.
output_cscformat=False,
): ):
super().__init__(datapipe) super().__init__(datapipe)
self.graph = graph self.graph = graph
self.output_cscformat = output_cscformat
self.sampler = graph.in_subgraph self.sampler = graph.in_subgraph
def _sample_subgraphs(self, seeds): def _sample_subgraphs(self, seeds):
subgraph = self.sampler(seeds) subgraph = self.sampler(seeds, self.output_cscformat)
if not self.output_cscformat:
( (
original_row_node_ids, original_row_node_ids,
compacted_node_pairs, compacted_node_pairs,
...@@ -73,5 +80,16 @@ class InSubgraphSampler(SubgraphSampler): ...@@ -73,5 +80,16 @@ class InSubgraphSampler(SubgraphSampler):
original_row_node_ids=original_row_node_ids, original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids, original_edge_ids=subgraph.original_edge_ids,
) )
else:
(
original_row_node_ids,
compacted_csc_formats,
) = unique_and_compact_csc_formats(subgraph.node_pairs, seeds)
subgraph = SampledSubgraphImpl(
node_pairs=compacted_csc_formats,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
)
seeds = original_row_node_ids seeds = original_row_node_ids
return (seeds, [subgraph]) return (seeds, [subgraph])
...@@ -532,7 +532,7 @@ def test_multiprocessing(): ...@@ -532,7 +532,7 @@ def test_multiprocessing():
F._default_context_str == "gpu", F._default_context_str == "gpu",
reason="Graph is CPU only at present.", reason="Graph is CPU only at present.",
) )
def test_in_subgraph_homogeneous(): def test_in_subgraph_node_pairs_homogeneous():
"""Original graph in COO: """Original graph in COO:
1 0 1 0 1 1 0 1 0 1
1 0 1 1 0 1 0 1 1 0
...@@ -573,7 +573,7 @@ def test_in_subgraph_homogeneous(): ...@@ -573,7 +573,7 @@ def test_in_subgraph_homogeneous():
F._default_context_str == "gpu", F._default_context_str == "gpu",
reason="Graph is CPU only at present.", reason="Graph is CPU only at present.",
) )
def test_in_subgraph_heterogeneous(): def test_in_subgraph_node_pairs_heterogeneous():
"""Original graph in COO: """Original graph in COO:
1 0 1 0 1 1 0 1 0 1
1 0 1 1 0 1 0 1 1 0
...@@ -659,6 +659,137 @@ def test_in_subgraph_heterogeneous(): ...@@ -659,6 +659,137 @@ def test_in_subgraph_heterogeneous():
) )
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Graph is CPU only at present.",
)
def test_in_subgraph_homo():
    """Check `in_subgraph(..., output_cscformat=True)` on a homogeneous graph.

    Adjacency matrix of the original graph (column ``j`` holds the inbound
    edges of node ``j``, which is what the CSC arrays below encode):
    1   0   1   0   1
    1   0   1   1   0
    0   1   0   1   0
    0   1   0   0   1
    1   0   0   0   1

    The in-subgraph of nodes ``[4, 1, 3]`` is expected to list the inbound
    edges of those nodes in the order the nodes were requested, still using
    the original node IDs (no compaction, so both ``original_*_node_ids``
    fields are ``None``).
    """
    # Initialize data.
    total_num_nodes = 5
    total_num_edges = 12
    indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
    indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
    # Sanity-check the fixture itself before exercising the graph: a CSC
    # `indptr` has one entry per node plus a trailing total-edge count.
    assert len(indptr) == total_num_nodes + 1
    assert indptr[-1] == total_num_edges
    assert indptr[-1] == len(indices)

    # Construct FusedCSCSamplingGraph.
    graph = gb.fused_csc_sampling_graph(indptr, indices)

    # Extract in subgraph.
    nodes = torch.LongTensor([4, 1, 3])
    in_subgraph = graph.in_subgraph(nodes, output_cscformat=True)

    # Verify in subgraph.
    # Node 4 contributes edges 9..11, node 1 edges 3..4, node 3 edges 7..8.
    assert torch.equal(
        in_subgraph.node_pairs.indices, torch.LongTensor([0, 3, 4, 2, 3, 1, 2])
    )
    assert torch.equal(
        in_subgraph.node_pairs.indptr, torch.LongTensor([0, 3, 5, 7])
    )
    assert in_subgraph.original_column_node_ids is None
    assert in_subgraph.original_row_node_ids is None
    assert torch.equal(
        in_subgraph.original_edge_ids, torch.LongTensor([9, 10, 11, 3, 4, 7, 8])
    )
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Graph is CPU only at present.",
)
def test_in_subgraph_hetero():
    """Check `in_subgraph(..., output_cscformat=True)` on a heterogeneous graph.

    Adjacency matrix of the original graph (column ``j`` holds the inbound
    edges of node ``j``):
    1   0   1   0   1
    1   0   1   1   0
    0   1   0   1   0
    0   1   0   0   1
    1   0   0   0   1

    node_type_0: [0, 1]
    node_type_1: [2, 3, 4]
    edge_type_0: node_type_0 -> node_type_0
    edge_type_1: node_type_0 -> node_type_1
    edge_type_2: node_type_1 -> node_type_0
    edge_type_3: node_type_1 -> node_type_1

    The result is expected to hold one CSC structure per edge type, with
    type-local row indices, and no compaction metadata.
    """
    # Fixture: CSC arrays plus the per-node / per-edge type information.
    num_nodes, num_edges = 5, 12
    node_type_to_id = {
        "N0": 0,
        "N1": 1,
    }
    edge_type_to_id = {
        "N0:R0:N0": 0,
        "N0:R1:N1": 1,
        "N1:R2:N0": 2,
        "N1:R3:N1": 3,
    }
    csc_indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
    row_indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
    node_type_offset = torch.LongTensor([0, 2, 5])
    type_per_edge = torch.LongTensor([0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])

    # The fixture must be self-consistent before it is handed to the graph.
    assert csc_indptr[-1] == num_edges
    assert csc_indptr[-1] == len(row_indices)
    assert node_type_offset[-1] == num_nodes
    assert all(type_per_edge < len(edge_type_to_id))

    # Construct FusedCSCSamplingGraph.
    graph = gb.fused_csc_sampling_graph(
        csc_indptr,
        row_indices,
        node_type_offset=node_type_offset,
        type_per_edge=type_per_edge,
        node_type_to_id=node_type_to_id,
        edge_type_to_id=edge_type_to_id,
    )

    # Extract the in-subgraph of type-local seeds.
    seeds = {
        "N0": torch.LongTensor([1]),
        "N1": torch.LongTensor([2, 1]),
    }
    result = graph.in_subgraph(seeds, output_cscformat=True)

    # Expected (indices, indptr) for every edge type, checked in this order.
    expected_csc = {
        "N0:R0:N0": ([], [0, 0]),
        "N0:R1:N1": ([0, 1], [0, 1, 2]),
        "N1:R2:N0": ([0, 1], [0, 2]),
        "N1:R3:N1": ([1, 2, 0], [0, 2, 3]),
    }
    for etype, (want_indices, want_indptr) in expected_csc.items():
        assert torch.equal(
            result.node_pairs[etype].indices, torch.LongTensor(want_indices)
        )
        assert torch.equal(
            result.node_pairs[etype].indptr, torch.LongTensor(want_indptr)
        )

    # No compaction happens at this level, so no ID mappings are attached.
    assert result.original_column_node_ids is None
    assert result.original_row_node_ids is None
    assert torch.equal(
        result.original_edge_ids, torch.LongTensor([3, 4, 9, 10, 11, 7, 8])
    )
@unittest.skipIf( @unittest.skipIf(
F._default_context_str == "gpu", F._default_context_str == "gpu",
reason="Graph is CPU only at present.", reason="Graph is CPU only at present.",
......
...@@ -62,7 +62,7 @@ def test_index_select_csc(indptr_dtype, indices_dtype, idtype, is_pinned): ...@@ -62,7 +62,7 @@ def test_index_select_csc(indptr_dtype, indices_dtype, idtype, is_pinned):
assert torch.equal(cpu_indices, gpu_indices.cpu()) assert torch.equal(cpu_indices, gpu_indices.cpu())
def test_InSubgraphSampler_homo(): def test_InSubgraphSampler_node_pairs_homo():
"""Original graph in COO: """Original graph in COO:
1 0 1 0 1 0 1 0 1 0 1 0
1 0 0 1 0 1 1 0 0 1 0 1
...@@ -109,7 +109,7 @@ def test_InSubgraphSampler_homo(): ...@@ -109,7 +109,7 @@ def test_InSubgraphSampler_homo():
assert original_node_pairs(mn) == ([1, 2], [3, 3]) assert original_node_pairs(mn) == ([1, 2], [3, 3])
def test_InSubgraphSampler_hetero(): def test_InSubgraphSampler_node_pairs_hetero():
"""Original graph in COO: """Original graph in COO:
1 0 1 0 1 0 1 0 1 0 1 0
1 0 0 1 0 1 1 0 0 1 0 1
...@@ -198,3 +198,175 @@ def test_InSubgraphSampler_hetero(): ...@@ -198,3 +198,175 @@ def test_InSubgraphSampler_hetero():
for etype, pairs in mn.sampled_subgraphs[0].node_pairs.items(): for etype, pairs in mn.sampled_subgraphs[0].node_pairs.items():
assert torch.equal(pairs[0], expected_node_pairs[etype][0]) assert torch.equal(pairs[0], expected_node_pairs[etype][0])
assert torch.equal(pairs[1], expected_node_pairs[etype][1]) assert torch.equal(pairs[1], expected_node_pairs[etype][1])
def test_InSubgraphSampler_homo():
    """Check `InSubgraphSampler` with `output_cscformat=True`, homogeneous case.

    Adjacency matrix of the original graph (column ``j`` holds the inbound
    edges of node ``j``):
    1   0   1   0   1   0
    1   0   0   1   0   1
    0   1   0   1   0   0
    0   1   0   0   1   0
    1   0   0   0   0   1
    0   0   1   0   1   0

    Seeds are fed one per minibatch; each minibatch must carry a single-column
    CSC structure whose row indices map back to the seed's in-neighbours.
    """
    csc_indptr = torch.LongTensor([0, 3, 5, 7, 9, 12, 14])
    row_indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4])
    graph = gb.fused_csc_sampling_graph(csc_indptr, row_indices)

    item_set = gb.ItemSet(torch.LongTensor([0, 5, 3]), names="seed_nodes")
    item_sampler = gb.ItemSampler(item_set, batch_size=1)
    sampler = gb.InSubgraphSampler(item_sampler, graph, output_cscformat=True)
    batches = iter(sampler)

    def _rows_in_original_ids(minibatch):
        # Undo the compaction: look compacted row indices up in the
        # original-ID table stored on the sampled subgraph.
        subgraph = minibatch.sampled_subgraphs[0]
        return subgraph.original_row_node_ids[subgraph.node_pairs.indices]

    # One entry per expected minibatch: (seed, indptr, original row IDs).
    expectations = (
        ([0], [0, 3], [0, 1, 4]),
        ([5], [0, 2], [1, 4]),
        ([3], [0, 2], [1, 2]),
    )
    for want_seed, want_indptr, want_rows in expectations:
        mn = next(batches)
        assert torch.equal(mn.seed_nodes, torch.LongTensor(want_seed))
        assert torch.equal(
            mn.sampled_subgraphs[0].node_pairs.indptr,
            torch.tensor(want_indptr),
        )
        assert torch.equal(_rows_in_original_ids(mn), torch.tensor(want_rows))
def test_InSubgraphSampler_hetero():
    """Check `InSubgraphSampler` with `output_cscformat=True`, heterogeneous case.

    Adjacency matrix of the original graph (column ``j`` holds the inbound
    edges of node ``j``):
    1   0   1   0   1   0
    1   0   0   1   0   1
    0   1   0   1   0   0
    0   1   0   0   1   0
    1   0   0   0   0   1
    0   0   1   0   1   0

    node_type_0: [0, 1, 2]
    node_type_1: [3, 4, 5]
    edge_type_0: node_type_0 -> node_type_0
    edge_type_1: node_type_0 -> node_type_1
    edge_type_2: node_type_1 -> node_type_0
    edge_type_3: node_type_1 -> node_type_1

    With a batch size of 2 over three "N0" seeds followed by three "N1"
    seeds, the three minibatches checked here hold N0-only, mixed and
    N1-only seeds respectively.
    """
    ntypes = {
        "N0": 0,
        "N1": 1,
    }
    etypes = {
        "N0:R0:N0": 0,
        "N0:R1:N1": 1,
        "N1:R2:N0": 2,
        "N1:R3:N1": 3,
    }
    indptr = torch.LongTensor([0, 3, 5, 7, 9, 12, 14])
    indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4])
    node_type_offset = torch.LongTensor([0, 3, 6])
    type_per_edge = torch.LongTensor([0, 0, 2, 0, 2, 0, 2, 1, 1, 1, 3, 3, 1, 3])
    graph = gb.fused_csc_sampling_graph(
        csc_indptr=indptr,
        indices=indices,
        node_type_offset=node_type_offset,
        type_per_edge=type_per_edge,
        node_type_to_id=ntypes,
        edge_type_to_id=etypes,
    )

    item_set = gb.ItemSetDict(
        {
            "N0": gb.ItemSet(torch.LongTensor([1, 0, 2]), names="seed_nodes"),
            "N1": gb.ItemSet(torch.LongTensor([0, 2, 1]), names="seed_nodes"),
        }
    )
    batch_size = 2
    item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
    in_subgraph_sampler = gb.InSubgraphSampler(
        item_sampler, graph, output_cscformat=True
    )
    it = iter(in_subgraph_sampler)

    def _csc(indptr_values, indices_values):
        # Shorthand for an expected per-edge-type CSC structure.
        return gb.CSCFormatBase(
            indptr=torch.LongTensor(indptr_values),
            indices=torch.LongTensor(indices_values),
        )

    def _assert_node_pairs(minibatch, expected):
        # Compare every edge type present in the sampled subgraph against
        # its expected CSC structure.
        for etype, pairs in minibatch.sampled_subgraphs[0].node_pairs.items():
            assert torch.equal(pairs.indices, expected[etype].indices)
            assert torch.equal(pairs.indptr, expected[etype].indptr)

    # First minibatch: two "N0" seeds only.
    mn = next(it)
    assert torch.equal(mn.seed_nodes["N0"], torch.LongTensor([1, 0]))
    _assert_node_pairs(
        mn,
        {
            "N0:R0:N0": _csc([0, 1, 3], [2, 1, 0]),
            "N0:R1:N1": _csc([0], []),
            "N1:R2:N0": _csc([0, 1, 2], [0, 1]),
            "N1:R3:N1": _csc([0], []),
        },
    )

    # Second minibatch: one seed of each node type.
    mn = next(it)
    # Compare tensor by tensor: `dict == dict` truth-tests the element-wise
    # tensor comparison, which only works for single-element tensors.
    assert set(mn.seed_nodes) == {"N0", "N1"}
    assert torch.equal(mn.seed_nodes["N0"], torch.LongTensor([2]))
    assert torch.equal(mn.seed_nodes["N1"], torch.LongTensor([0]))
    _assert_node_pairs(
        mn,
        {
            "N0:R0:N0": _csc([0, 1], [1]),
            "N0:R1:N1": _csc([0, 2], [2, 0]),
            "N1:R2:N0": _csc([0, 1], [1]),
            "N1:R3:N1": _csc([0, 0], []),
        },
    )

    # Third minibatch: two "N1" seeds only.
    mn = next(it)
    assert torch.equal(mn.seed_nodes["N1"], torch.LongTensor([2, 1]))
    _assert_node_pairs(
        mn,
        {
            "N0:R0:N0": _csc([0], []),
            "N0:R1:N1": _csc([0, 1, 2], [0, 1]),
            "N1:R2:N0": _csc([0], []),
            "N1:R3:N1": _csc([0, 1, 3], [1, 2, 0]),
        },
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment