Unverified commit a1724c18, authored by yxy235 and committed by GitHub
Browse files

[GraphBolt] Modify `InSubgraphSampler` to support csc format. (#6724)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent b0d26199
......@@ -301,8 +301,11 @@ class FusedCSCSamplingGraph(SamplingGraph):
self._c_csc_graph.set_edge_attributes(edge_attributes)
def in_subgraph(
self, nodes: Union[torch.Tensor, Dict[str, torch.Tensor]]
) -> FusedSampledSubgraphImpl:
self,
nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
# TODO: clean up once the migration is done.
output_cscformat=False,
) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
"""Return the subgraph induced on the inbound edges of the given nodes.
An in subgraph is equivalent to creating a new graph using the incoming
......@@ -361,7 +364,10 @@ class FusedCSCSamplingGraph(SamplingGraph):
), "Nodes cannot have duplicate values."
_in_subgraph = self._c_csc_graph.in_subgraph(nodes)
return self._convert_to_fused_sampled_subgraph(_in_subgraph)
if not output_cscformat:
return self._convert_to_fused_sampled_subgraph(_in_subgraph)
else:
return self._convert_to_sampled_subgraph(_in_subgraph)
def _convert_to_fused_sampled_subgraph(
self,
......
......@@ -2,10 +2,13 @@
from torch.utils.data import functional_datapipe
from ..internal import unique_and_compact_node_pairs
from ..internal import (
unique_and_compact_csc_formats,
unique_and_compact_node_pairs,
)
from ..subgraph_sampler import SubgraphSampler
from .sampled_subgraph_impl import FusedSampledSubgraphImpl
from .sampled_subgraph_impl import FusedSampledSubgraphImpl, SampledSubgraphImpl
__all__ = ["InSubgraphSampler"]
......@@ -56,22 +59,37 @@ class InSubgraphSampler(SubgraphSampler):
self,
datapipe,
graph,
# TODO: clean up once the migration is done.
output_cscformat=False,
):
super().__init__(datapipe)
self.graph = graph
self.output_cscformat = output_cscformat
self.sampler = graph.in_subgraph
def _sample_subgraphs(self, seeds):
    """Extract the in-subgraph of ``seeds`` and compact its node IDs.

    ``self.sampler`` (bound to ``graph.in_subgraph`` in ``__init__``) is
    called exactly once, with ``self.output_cscformat`` selecting the
    representation it returns. The returned ``node_pairs`` are then passed,
    together with ``seeds``, to the compaction helper matching that
    representation, and the result is wrapped in the corresponding
    sampled-subgraph class.

    Parameters
    ----------
    seeds : torch.Tensor or Dict[str, torch.Tensor]
        Seed nodes, forwarded unchanged to ``graph.in_subgraph`` and stored
        as ``original_column_node_ids`` of the returned subgraph.

    Returns
    -------
    tuple
        ``(original_row_node_ids, [subgraph])`` where ``subgraph`` is a
        ``SampledSubgraphImpl`` when ``self.output_cscformat`` is true and
        a ``FusedSampledSubgraphImpl`` otherwise.
    """
    # TODO: clean up once the migration is done.
    sampled = self.sampler(seeds, self.output_cscformat)

    # The two output formats differ only in which compaction helper is used
    # and which container class holds the result.
    if self.output_cscformat:
        compact = unique_and_compact_csc_formats
        subgraph_cls = SampledSubgraphImpl
    else:
        compact = unique_and_compact_node_pairs
        subgraph_cls = FusedSampledSubgraphImpl

    original_row_node_ids, compacted = compact(sampled.node_pairs, seeds)
    subgraph = subgraph_cls(
        node_pairs=compacted,
        original_column_node_ids=seeds,
        original_row_node_ids=original_row_node_ids,
        original_edge_ids=sampled.original_edge_ids,
    )
    # The row nodes of this subgraph are handed back as the new seeds.
    return (original_row_node_ids, [subgraph])
......@@ -532,7 +532,7 @@ def test_multiprocessing():
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
def test_in_subgraph_homogeneous():
def test_in_subgraph_node_pairs_homogeneous():
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
......@@ -573,7 +573,7 @@ def test_in_subgraph_homogeneous():
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
def test_in_subgraph_heterogeneous():
def test_in_subgraph_node_pairs_heterogeneous():
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
......@@ -659,6 +659,137 @@ def test_in_subgraph_heterogeneous():
)
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Graph is CPU only at present.",
)
def test_in_subgraph_homo():
    """Check ``in_subgraph(..., output_cscformat=True)`` on a homogeneous graph.

    Original graph in COO:
    1 0 1 0 1
    1 0 1 1 0
    0 1 0 1 0
    0 1 0 0 1
    1 0 0 0 1
    """
    # Initialize data.
    total_num_nodes = 5
    total_num_edges = 12
    indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
    indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
    # Sanity-check the fixture itself before exercising the graph API.
    assert len(indptr) - 1 == total_num_nodes
    assert indptr[-1] == total_num_edges
    assert indptr[-1] == len(indices)

    # Construct FusedCSCSamplingGraph.
    graph = gb.fused_csc_sampling_graph(indptr, indices)

    # Extract in subgraph.
    nodes = torch.LongTensor([4, 1, 3])
    in_subgraph = graph.in_subgraph(nodes, output_cscformat=True)

    # Verify in subgraph.
    assert torch.equal(
        in_subgraph.node_pairs.indices, torch.LongTensor([0, 3, 4, 2, 3, 1, 2])
    )
    assert torch.equal(
        in_subgraph.node_pairs.indptr, torch.LongTensor([0, 3, 5, 7])
    )
    assert in_subgraph.original_column_node_ids is None
    assert in_subgraph.original_row_node_ids is None
    assert torch.equal(
        in_subgraph.original_edge_ids, torch.LongTensor([9, 10, 11, 3, 4, 7, 8])
    )
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Graph is CPU only at present.",
)
def test_in_subgraph_hetero():
    """Check ``in_subgraph(..., output_cscformat=True)`` on a heterogeneous graph.

    Original graph in COO:
    1 0 1 0 1
    1 0 1 1 0
    0 1 0 1 0
    0 1 0 0 1
    1 0 0 0 1

    node_type_0: [0, 1]
    node_type_1: [2, 3, 4]
    edge_type_0: node_type_0 -> node_type_0
    edge_type_1: node_type_0 -> node_type_1
    edge_type_2: node_type_1 -> node_type_0
    edge_type_3: node_type_1 -> node_type_1
    """
    # Fixture: graph structure plus type metadata.
    num_nodes = 5
    num_edges = 12
    ntypes = {"N0": 0, "N1": 1}
    etypes = {
        "N0:R0:N0": 0,
        "N0:R1:N1": 1,
        "N1:R2:N0": 2,
        "N1:R3:N1": 3,
    }
    indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
    indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
    node_type_offset = torch.LongTensor([0, 2, 5])
    type_per_edge = torch.LongTensor([0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])

    # Fixture self-consistency checks.
    assert indptr[-1] == num_edges
    assert indptr[-1] == len(indices)
    assert node_type_offset[-1] == num_nodes
    assert all(type_per_edge < len(etypes))

    # Build the FusedCSCSamplingGraph under test.
    graph = gb.fused_csc_sampling_graph(
        indptr,
        indices,
        node_type_offset=node_type_offset,
        type_per_edge=type_per_edge,
        node_type_to_id=ntypes,
        edge_type_to_id=etypes,
    )

    # Request the in-subgraph of per-type seed nodes in CSC output format.
    seeds = {
        "N0": torch.LongTensor([1]),
        "N1": torch.LongTensor([2, 1]),
    }
    result = graph.in_subgraph(seeds, output_cscformat=True)

    # Expected (indices, indptr) per edge type, checked in this order.
    expected_csc = {
        "N0:R0:N0": ([], [0, 0]),
        "N0:R1:N1": ([0, 1], [0, 1, 2]),
        "N1:R2:N0": ([0, 1], [0, 2]),
        "N1:R3:N1": ([1, 2, 0], [0, 2, 3]),
    }
    for etype, (want_indices, want_indptr) in expected_csc.items():
        csc = result.node_pairs[etype]
        assert torch.equal(csc.indices, torch.LongTensor(want_indices))
        assert torch.equal(csc.indptr, torch.LongTensor(want_indptr))

    assert result.original_column_node_ids is None
    assert result.original_row_node_ids is None
    assert torch.equal(
        result.original_edge_ids, torch.LongTensor([3, 4, 9, 10, 11, 7, 8])
    )
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
......
......@@ -62,7 +62,7 @@ def test_index_select_csc(indptr_dtype, indices_dtype, idtype, is_pinned):
assert torch.equal(cpu_indices, gpu_indices.cpu())
def test_InSubgraphSampler_homo():
def test_InSubgraphSampler_node_pairs_homo():
"""Original graph in COO:
1 0 1 0 1 0
1 0 0 1 0 1
......@@ -109,7 +109,7 @@ def test_InSubgraphSampler_homo():
assert original_node_pairs(mn) == ([1, 2], [3, 3])
def test_InSubgraphSampler_hetero():
def test_InSubgraphSampler_node_pairs_hetero():
"""Original graph in COO:
1 0 1 0 1 0
1 0 0 1 0 1
......@@ -198,3 +198,175 @@ def test_InSubgraphSampler_hetero():
for etype, pairs in mn.sampled_subgraphs[0].node_pairs.items():
assert torch.equal(pairs[0], expected_node_pairs[etype][0])
assert torch.equal(pairs[1], expected_node_pairs[etype][1])
def test_InSubgraphSampler_homo():
    """Check ``InSubgraphSampler`` with ``output_cscformat=True`` (homogeneous).

    Original graph in COO:
    1 0 1 0 1 0
    1 0 0 1 0 1
    0 1 0 1 0 0
    0 1 0 0 1 0
    1 0 0 0 0 1
    0 0 1 0 1 0
    """
    indptr = torch.LongTensor([0, 3, 5, 7, 9, 12, 14])
    indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4])
    graph = gb.fused_csc_sampling_graph(indptr, indices)

    # One seed node per minibatch.
    item_set = gb.ItemSet(torch.LongTensor([0, 5, 3]), names="seed_nodes")
    item_sampler = gb.ItemSampler(item_set, batch_size=1)
    sampler = gb.InSubgraphSampler(item_sampler, graph, output_cscformat=True)
    batches = iter(sampler)

    def original_indices(minibatch):
        # Map the subgraph's ``indices`` through ``original_row_node_ids``.
        sg = minibatch.sampled_subgraphs[0]
        return sg.original_row_node_ids[sg.node_pairs.indices]

    # (seed node, expected indptr, expected original source node IDs),
    # in the order the minibatches are produced.
    expectations = [
        (0, [0, 3], [0, 1, 4]),
        (5, [0, 2], [1, 4]),
        (3, [0, 2], [1, 2]),
    ]
    for seed, want_indptr, want_sources in expectations:
        mn = next(batches)
        assert torch.equal(mn.seed_nodes, torch.LongTensor([seed]))
        assert torch.equal(
            mn.sampled_subgraphs[0].node_pairs.indptr,
            torch.tensor(want_indptr),
        )
        assert torch.equal(original_indices(mn), torch.tensor(want_sources))
def test_InSubgraphSampler_hetero():
    """Check ``InSubgraphSampler`` with ``output_cscformat=True`` (heterogeneous).

    Original graph in COO:
    1 0 1 0 1 0
    1 0 0 1 0 1
    0 1 0 1 0 0
    0 1 0 0 1 0
    1 0 0 0 0 1
    0 0 1 0 1 0

    node_type_0: [0, 1, 2]
    node_type_1: [3, 4, 5]
    edge_type_0: node_type_0 -> node_type_0
    edge_type_1: node_type_0 -> node_type_1
    edge_type_2: node_type_1 -> node_type_0
    edge_type_3: node_type_1 -> node_type_1
    """
    ntypes = {
        "N0": 0,
        "N1": 1,
    }
    etypes = {
        "N0:R0:N0": 0,
        "N0:R1:N1": 1,
        "N1:R2:N0": 2,
        "N1:R3:N1": 3,
    }
    indptr = torch.LongTensor([0, 3, 5, 7, 9, 12, 14])
    indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4])
    node_type_offset = torch.LongTensor([0, 3, 6])
    type_per_edge = torch.LongTensor([0, 0, 2, 0, 2, 0, 2, 1, 1, 1, 3, 3, 1, 3])
    graph = gb.fused_csc_sampling_graph(
        csc_indptr=indptr,
        indices=indices,
        node_type_offset=node_type_offset,
        type_per_edge=type_per_edge,
        node_type_to_id=ntypes,
        edge_type_to_id=etypes,
    )
    item_set = gb.ItemSetDict(
        {
            "N0": gb.ItemSet(torch.LongTensor([1, 0, 2]), names="seed_nodes"),
            "N1": gb.ItemSet(torch.LongTensor([0, 2, 1]), names="seed_nodes"),
        }
    )
    batch_size = 2
    item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
    in_subgraph_sampler = gb.InSubgraphSampler(
        item_sampler, graph, output_cscformat=True
    )
    it = iter(in_subgraph_sampler)

    def assert_node_pairs(minibatch, expected):
        # Compare every edge type present in the sampled subgraph against
        # the expected CSC structure for that edge type.
        for etype, pairs in minibatch.sampled_subgraphs[0].node_pairs.items():
            assert torch.equal(pairs.indices, expected[etype].indices)
            assert torch.equal(pairs.indptr, expected[etype].indptr)

    # First minibatch: two seeds of type N0.
    mn = next(it)
    assert torch.equal(mn.seed_nodes["N0"], torch.LongTensor([1, 0]))
    assert_node_pairs(
        mn,
        {
            "N0:R0:N0": gb.CSCFormatBase(
                indptr=torch.LongTensor([0, 1, 3]),
                indices=torch.LongTensor([2, 1, 0]),
            ),
            "N0:R1:N1": gb.CSCFormatBase(
                indptr=torch.LongTensor([0]), indices=torch.LongTensor([])
            ),
            "N1:R2:N0": gb.CSCFormatBase(
                indptr=torch.LongTensor([0, 1, 2]),
                indices=torch.LongTensor([0, 1]),
            ),
            "N1:R3:N1": gb.CSCFormatBase(
                indptr=torch.LongTensor([0]), indices=torch.LongTensor([])
            ),
        },
    )

    # Second minibatch: one seed of each node type.
    mn = next(it)
    expected_seeds = {
        "N0": torch.LongTensor([2]),
        "N1": torch.LongTensor([0]),
    }
    # Compare keys and tensors explicitly: ``dict == dict`` would rely on the
    # truth value of an element-wise tensor comparison, which only works for
    # one-element tensors and raises instead of failing cleanly otherwise.
    assert mn.seed_nodes.keys() == expected_seeds.keys()
    for ntype, expected_ids in expected_seeds.items():
        assert torch.equal(mn.seed_nodes[ntype], expected_ids)
    assert_node_pairs(
        mn,
        {
            "N0:R0:N0": gb.CSCFormatBase(
                indptr=torch.LongTensor([0, 1]), indices=torch.LongTensor([1])
            ),
            "N0:R1:N1": gb.CSCFormatBase(
                indptr=torch.LongTensor([0, 2]),
                indices=torch.LongTensor([2, 0]),
            ),
            "N1:R2:N0": gb.CSCFormatBase(
                indptr=torch.LongTensor([0, 1]), indices=torch.LongTensor([1])
            ),
            "N1:R3:N1": gb.CSCFormatBase(
                indptr=torch.LongTensor([0, 0]), indices=torch.LongTensor([])
            ),
        },
    )

    # Third minibatch: two seeds of type N1.
    mn = next(it)
    assert torch.equal(mn.seed_nodes["N1"], torch.LongTensor([2, 1]))
    assert_node_pairs(
        mn,
        {
            "N0:R0:N0": gb.CSCFormatBase(
                indptr=torch.LongTensor([0]), indices=torch.LongTensor([])
            ),
            "N0:R1:N1": gb.CSCFormatBase(
                indptr=torch.LongTensor([0, 1, 2]),
                indices=torch.LongTensor([0, 1]),
            ),
            "N1:R2:N0": gb.CSCFormatBase(
                indptr=torch.LongTensor([0]), indices=torch.LongTensor([])
            ),
            "N1:R3:N1": gb.CSCFormatBase(
                indptr=torch.LongTensor([0, 1, 3]),
                indices=torch.LongTensor([1, 2, 0]),
            ),
        },
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment