Unverified commit 93b39729, authored by yxy235, committed by GitHub
Browse files

[GraphBolt] Add option to select node_pairs or csc_format. (#6689)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 99fcdfec
......@@ -505,7 +505,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
fanouts: torch.Tensor,
replace: bool = False,
probs_name: Optional[str] = None,
deduplicate=True,
# TODO: clean up once the migration is done.
output_cscformat=False,
) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
"""Sample neighboring edges of the given nodes and return the induced
subgraph.
......@@ -576,7 +577,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
C_sampled_subgraph = self._sample_neighbors(
nodes, fanouts, replace, probs_name
)
if deduplicate is True:
if not output_cscformat:
return self._convert_to_fused_sampled_subgraph(C_sampled_subgraph)
else:
return self._convert_to_sampled_subgraph(C_sampled_subgraph)
......@@ -686,7 +687,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
fanouts: torch.Tensor,
replace: bool = False,
probs_name: Optional[str] = None,
deduplicate=True,
# TODO: clean up once the migration is done.
output_cscformat=False,
) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
"""Sample neighboring edges of the given nodes and return the induced
subgraph via layer-neighbor sampling from the NeurIPS 2023 paper
......@@ -770,7 +772,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
probs_name,
)
if deduplicate:
if not output_cscformat:
return self._convert_to_fused_sampled_subgraph(C_sampled_subgraph)
else:
return self._convert_to_sampled_subgraph(C_sampled_subgraph)
......
......@@ -86,6 +86,8 @@ class NeighborSampler(SubgraphSampler):
replace=False,
prob_name=None,
deduplicate=True,
# TODO: clean up once the migration is done.
output_cscformat=False,
):
super().__init__(datapipe)
self.graph = graph
......@@ -98,6 +100,7 @@ class NeighborSampler(SubgraphSampler):
self.replace = replace
self.prob_name = prob_name
self.deduplicate = deduplicate
self.output_cscformat = output_cscformat
self.sampler = graph.sample_neighbors
def _sample_subgraphs(self, seeds):
......@@ -116,9 +119,11 @@ class NeighborSampler(SubgraphSampler):
self.fanouts[hop],
self.replace,
self.prob_name,
self.deduplicate,
not self.deduplicate or self.output_cscformat,
)
if self.deduplicate:
if self.output_cscformat:
raise RuntimeError("Not implemented yet.")
(
original_row_node_ids,
compacted_node_pairs,
......
......@@ -1720,7 +1720,7 @@ def test_sample_neighbors_homo_csc_format():
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
subgraph = graph.sample_neighbors(
nodes, fanouts=torch.LongTensor([2]), deduplicate=False
nodes, fanouts=torch.LongTensor([2]), output_cscformat=True
)
# Verify in subgraph.
......@@ -1774,7 +1774,7 @@ def test_sample_neighbors_hetero_csc_format(labor):
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts, deduplicate=False)
subgraph = sampler(nodes, fanouts, output_cscformat=True)
# Verify in subgraph.
expected_node_pairs = {
......@@ -1799,7 +1799,7 @@ def test_sample_neighbors_hetero_csc_format(labor):
nodes = {"n1": torch.LongTensor([0])}
fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts, deduplicate=False)
subgraph = sampler(nodes, fanouts, output_cscformat=True)
# Verify in subgraph.
expected_node_pairs = {
......@@ -1879,7 +1879,7 @@ def test_sample_neighbors_fanouts_csc_format(
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
fanouts = torch.LongTensor(fanouts)
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts, deduplicate=False)
subgraph = sampler(nodes, fanouts, output_cscformat=True)
# Verify in subgraph.
assert (
......@@ -1940,7 +1940,7 @@ def test_sample_neighbors_replace_csc_format(
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
subgraph = graph.sample_neighbors(
nodes, torch.LongTensor([4]), replace=replace, deduplicate=False
nodes, torch.LongTensor([4]), replace=replace, output_cscformat=True
)
# Verify in subgraph.
......@@ -1984,7 +1984,9 @@ def test_sample_neighbors_return_eids_homo_csc_format(labor):
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts=torch.LongTensor([-1]), deduplicate=False)
subgraph = sampler(
nodes, fanouts=torch.LongTensor([-1]), output_cscformat=True
)
# Verify in subgraph.
expected_reverse_edge_ids = edge_attributes[gb.ORIGINAL_EDGE_ID][
......@@ -2041,7 +2043,7 @@ def test_sample_neighbors_return_eids_hetero_csc_format(labor):
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts, deduplicate=False)
subgraph = sampler(nodes, fanouts, output_cscformat=True)
# Verify in subgraph.
expected_reverse_edge_ids = {
......@@ -2098,7 +2100,7 @@ def test_sample_neighbors_probs_csc_format(replace, labor, probs_name):
fanouts=torch.tensor([2]),
replace=replace,
probs_name=probs_name,
deduplicate=False,
output_cscformat=True,
)
# Verify in subgraph.
......@@ -2145,7 +2147,7 @@ def test_sample_neighbors_zero_probs_csc_format(replace, labor, probs_or_mask):
fanouts=torch.tensor([5]),
replace=replace,
probs_name="probs_or_mask",
deduplicate=False,
output_cscformat=True,
)
# Verify in subgraph.
......@@ -2216,7 +2218,7 @@ def test_sample_neighbors_homo_pick_number_csc_format(
fanouts=torch.LongTensor(fanouts),
replace=replace,
probs_name=probs_name if probs_name != "none" else None,
deduplicate=False,
output_cscformat=True,
)
sampled_num = subgraph.node_pairs.indices.size(0)
assert subgraph.node_pairs.indptr.size(0) == 3
......@@ -2308,7 +2310,7 @@ def test_sample_neighbors_hetero_pick_number_csc_format(
fanouts=torch.LongTensor(fanouts),
replace=replace,
probs_name=probs_name if probs_name != "none" else None,
deduplicate=False,
output_cscformat=True,
)
print(subgraph)
if probs_name == "none":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment