"docs/source/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "fcf1193579ac81fe9899cd9dd61b2b3f05873c7a"
Unverified Commit 93b39729 authored by yxy235, committed by GitHub
Browse files

[GraphBolt] Add option to select node_pairs or csc_format. (#6689)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 99fcdfec
...@@ -505,7 +505,8 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -505,7 +505,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
fanouts: torch.Tensor, fanouts: torch.Tensor,
replace: bool = False, replace: bool = False,
probs_name: Optional[str] = None, probs_name: Optional[str] = None,
deduplicate=True, # TODO: clean up once the migration is done.
output_cscformat=False,
) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]: ) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
"""Sample neighboring edges of the given nodes and return the induced """Sample neighboring edges of the given nodes and return the induced
subgraph. subgraph.
...@@ -576,7 +577,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -576,7 +577,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
C_sampled_subgraph = self._sample_neighbors( C_sampled_subgraph = self._sample_neighbors(
nodes, fanouts, replace, probs_name nodes, fanouts, replace, probs_name
) )
if deduplicate is True: if not output_cscformat:
return self._convert_to_fused_sampled_subgraph(C_sampled_subgraph) return self._convert_to_fused_sampled_subgraph(C_sampled_subgraph)
else: else:
return self._convert_to_sampled_subgraph(C_sampled_subgraph) return self._convert_to_sampled_subgraph(C_sampled_subgraph)
...@@ -686,7 +687,8 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -686,7 +687,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
fanouts: torch.Tensor, fanouts: torch.Tensor,
replace: bool = False, replace: bool = False,
probs_name: Optional[str] = None, probs_name: Optional[str] = None,
deduplicate=True, # TODO: clean up once the migration is done.
output_cscformat=False,
) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]: ) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
"""Sample neighboring edges of the given nodes and return the induced """Sample neighboring edges of the given nodes and return the induced
subgraph via layer-neighbor sampling from the NeurIPS 2023 paper subgraph via layer-neighbor sampling from the NeurIPS 2023 paper
...@@ -770,7 +772,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -770,7 +772,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
probs_name, probs_name,
) )
if deduplicate: if not output_cscformat:
return self._convert_to_fused_sampled_subgraph(C_sampled_subgraph) return self._convert_to_fused_sampled_subgraph(C_sampled_subgraph)
else: else:
return self._convert_to_sampled_subgraph(C_sampled_subgraph) return self._convert_to_sampled_subgraph(C_sampled_subgraph)
......
...@@ -86,6 +86,8 @@ class NeighborSampler(SubgraphSampler): ...@@ -86,6 +86,8 @@ class NeighborSampler(SubgraphSampler):
replace=False, replace=False,
prob_name=None, prob_name=None,
deduplicate=True, deduplicate=True,
# TODO: clean up once the migration is done.
output_cscformat=False,
): ):
super().__init__(datapipe) super().__init__(datapipe)
self.graph = graph self.graph = graph
...@@ -98,6 +100,7 @@ class NeighborSampler(SubgraphSampler): ...@@ -98,6 +100,7 @@ class NeighborSampler(SubgraphSampler):
self.replace = replace self.replace = replace
self.prob_name = prob_name self.prob_name = prob_name
self.deduplicate = deduplicate self.deduplicate = deduplicate
self.output_cscformat = output_cscformat
self.sampler = graph.sample_neighbors self.sampler = graph.sample_neighbors
def _sample_subgraphs(self, seeds): def _sample_subgraphs(self, seeds):
...@@ -116,9 +119,11 @@ class NeighborSampler(SubgraphSampler): ...@@ -116,9 +119,11 @@ class NeighborSampler(SubgraphSampler):
self.fanouts[hop], self.fanouts[hop],
self.replace, self.replace,
self.prob_name, self.prob_name,
self.deduplicate, not self.deduplicate or self.output_cscformat,
) )
if self.deduplicate: if self.deduplicate:
if self.output_cscformat:
raise RuntimeError("Not implemented yet.")
( (
original_row_node_ids, original_row_node_ids,
compacted_node_pairs, compacted_node_pairs,
......
...@@ -1720,7 +1720,7 @@ def test_sample_neighbors_homo_csc_format(): ...@@ -1720,7 +1720,7 @@ def test_sample_neighbors_homo_csc_format():
# Generate subgraph via sample neighbors. # Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4]) nodes = torch.LongTensor([1, 3, 4])
subgraph = graph.sample_neighbors( subgraph = graph.sample_neighbors(
nodes, fanouts=torch.LongTensor([2]), deduplicate=False nodes, fanouts=torch.LongTensor([2]), output_cscformat=True
) )
# Verify in subgraph. # Verify in subgraph.
...@@ -1774,7 +1774,7 @@ def test_sample_neighbors_hetero_csc_format(labor): ...@@ -1774,7 +1774,7 @@ def test_sample_neighbors_hetero_csc_format(labor):
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])} nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
fanouts = torch.tensor([-1, -1]) fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts, deduplicate=False) subgraph = sampler(nodes, fanouts, output_cscformat=True)
# Verify in subgraph. # Verify in subgraph.
expected_node_pairs = { expected_node_pairs = {
...@@ -1799,7 +1799,7 @@ def test_sample_neighbors_hetero_csc_format(labor): ...@@ -1799,7 +1799,7 @@ def test_sample_neighbors_hetero_csc_format(labor):
nodes = {"n1": torch.LongTensor([0])} nodes = {"n1": torch.LongTensor([0])}
fanouts = torch.tensor([-1, -1]) fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts, deduplicate=False) subgraph = sampler(nodes, fanouts, output_cscformat=True)
# Verify in subgraph. # Verify in subgraph.
expected_node_pairs = { expected_node_pairs = {
...@@ -1879,7 +1879,7 @@ def test_sample_neighbors_fanouts_csc_format( ...@@ -1879,7 +1879,7 @@ def test_sample_neighbors_fanouts_csc_format(
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])} nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
fanouts = torch.LongTensor(fanouts) fanouts = torch.LongTensor(fanouts)
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts, deduplicate=False) subgraph = sampler(nodes, fanouts, output_cscformat=True)
# Verify in subgraph. # Verify in subgraph.
assert ( assert (
...@@ -1940,7 +1940,7 @@ def test_sample_neighbors_replace_csc_format( ...@@ -1940,7 +1940,7 @@ def test_sample_neighbors_replace_csc_format(
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])} nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
subgraph = graph.sample_neighbors( subgraph = graph.sample_neighbors(
nodes, torch.LongTensor([4]), replace=replace, deduplicate=False nodes, torch.LongTensor([4]), replace=replace, output_cscformat=True
) )
# Verify in subgraph. # Verify in subgraph.
...@@ -1984,7 +1984,9 @@ def test_sample_neighbors_return_eids_homo_csc_format(labor): ...@@ -1984,7 +1984,9 @@ def test_sample_neighbors_return_eids_homo_csc_format(labor):
# Generate subgraph via sample neighbors. # Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4]) nodes = torch.LongTensor([1, 3, 4])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts=torch.LongTensor([-1]), deduplicate=False) subgraph = sampler(
nodes, fanouts=torch.LongTensor([-1]), output_cscformat=True
)
# Verify in subgraph. # Verify in subgraph.
expected_reverse_edge_ids = edge_attributes[gb.ORIGINAL_EDGE_ID][ expected_reverse_edge_ids = edge_attributes[gb.ORIGINAL_EDGE_ID][
...@@ -2041,7 +2043,7 @@ def test_sample_neighbors_return_eids_hetero_csc_format(labor): ...@@ -2041,7 +2043,7 @@ def test_sample_neighbors_return_eids_hetero_csc_format(labor):
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])} nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
fanouts = torch.tensor([-1, -1]) fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts, deduplicate=False) subgraph = sampler(nodes, fanouts, output_cscformat=True)
# Verify in subgraph. # Verify in subgraph.
expected_reverse_edge_ids = { expected_reverse_edge_ids = {
...@@ -2098,7 +2100,7 @@ def test_sample_neighbors_probs_csc_format(replace, labor, probs_name): ...@@ -2098,7 +2100,7 @@ def test_sample_neighbors_probs_csc_format(replace, labor, probs_name):
fanouts=torch.tensor([2]), fanouts=torch.tensor([2]),
replace=replace, replace=replace,
probs_name=probs_name, probs_name=probs_name,
deduplicate=False, output_cscformat=True,
) )
# Verify in subgraph. # Verify in subgraph.
...@@ -2145,7 +2147,7 @@ def test_sample_neighbors_zero_probs_csc_format(replace, labor, probs_or_mask): ...@@ -2145,7 +2147,7 @@ def test_sample_neighbors_zero_probs_csc_format(replace, labor, probs_or_mask):
fanouts=torch.tensor([5]), fanouts=torch.tensor([5]),
replace=replace, replace=replace,
probs_name="probs_or_mask", probs_name="probs_or_mask",
deduplicate=False, output_cscformat=True,
) )
# Verify in subgraph. # Verify in subgraph.
...@@ -2216,7 +2218,7 @@ def test_sample_neighbors_homo_pick_number_csc_format( ...@@ -2216,7 +2218,7 @@ def test_sample_neighbors_homo_pick_number_csc_format(
fanouts=torch.LongTensor(fanouts), fanouts=torch.LongTensor(fanouts),
replace=replace, replace=replace,
probs_name=probs_name if probs_name != "none" else None, probs_name=probs_name if probs_name != "none" else None,
deduplicate=False, output_cscformat=True,
) )
sampled_num = subgraph.node_pairs.indices.size(0) sampled_num = subgraph.node_pairs.indices.size(0)
assert subgraph.node_pairs.indptr.size(0) == 3 assert subgraph.node_pairs.indptr.size(0) == 3
...@@ -2308,7 +2310,7 @@ def test_sample_neighbors_hetero_pick_number_csc_format( ...@@ -2308,7 +2310,7 @@ def test_sample_neighbors_hetero_pick_number_csc_format(
fanouts=torch.LongTensor(fanouts), fanouts=torch.LongTensor(fanouts),
replace=replace, replace=replace,
probs_name=probs_name if probs_name != "none" else None, probs_name=probs_name if probs_name != "none" else None,
deduplicate=False, output_cscformat=True,
) )
print(subgraph) print(subgraph)
if probs_name == "none": if probs_name == "none":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment