"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "6b9906f6c2688ab072dd32af47d94cc6ae5409bb"
Unverified Commit 3645e493 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Change default value from `coo` to `csc` in `FusedCSCSamplingGraph`. (#6833)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent ca72965a
...@@ -325,7 +325,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -325,7 +325,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
self, self,
nodes: Union[torch.Tensor, Dict[str, torch.Tensor]], nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
# TODO: clean up once the migration is done. # TODO: clean up once the migration is done.
output_cscformat=False, output_cscformat=True,
) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]: ) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
"""Return the subgraph induced on the inbound edges of the given nodes. """Return the subgraph induced on the inbound edges of the given nodes.
...@@ -555,7 +555,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -555,7 +555,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
replace: bool = False, replace: bool = False,
probs_name: Optional[str] = None, probs_name: Optional[str] = None,
# TODO: clean up once the migration is done. # TODO: clean up once the migration is done.
output_cscformat=False, output_cscformat=True,
) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]: ) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
"""Sample neighboring edges of the given nodes and return the induced """Sample neighboring edges of the given nodes and return the induced
subgraph. subgraph.
...@@ -745,7 +745,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -745,7 +745,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
replace: bool = False, replace: bool = False,
probs_name: Optional[str] = None, probs_name: Optional[str] = None,
# TODO: clean up once the migration is done. # TODO: clean up once the migration is done.
output_cscformat=False, output_cscformat=True,
) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]: ) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
"""Sample neighboring edges of the given nodes and return the induced """Sample neighboring edges of the given nodes and return the induced
subgraph via layer-neighbor sampling from the NeurIPS 2023 paper subgraph via layer-neighbor sampling from the NeurIPS 2023 paper
......
...@@ -706,7 +706,7 @@ def test_in_subgraph_node_pairs_homogeneous(): ...@@ -706,7 +706,7 @@ def test_in_subgraph_node_pairs_homogeneous():
# Extract in subgraph. # Extract in subgraph.
nodes = torch.LongTensor([4, 1, 3]) nodes = torch.LongTensor([4, 1, 3])
in_subgraph = graph.in_subgraph(nodes) in_subgraph = graph.in_subgraph(nodes, output_cscformat=False)
# Verify in subgraph. # Verify in subgraph.
assert torch.equal( assert torch.equal(
...@@ -778,7 +778,7 @@ def test_in_subgraph_node_pairs_heterogeneous(): ...@@ -778,7 +778,7 @@ def test_in_subgraph_node_pairs_heterogeneous():
"N0": torch.LongTensor([1]), "N0": torch.LongTensor([1]),
"N1": torch.LongTensor([2, 1]), "N1": torch.LongTensor([2, 1]),
} }
in_subgraph = graph.in_subgraph(nodes) in_subgraph = graph.in_subgraph(nodes, output_cscformat=False)
# Verify in subgraph. # Verify in subgraph.
assert torch.equal( assert torch.equal(
...@@ -991,7 +991,7 @@ def test_sample_neighbors_homo(labor, indptr_dtype, indices_dtype): ...@@ -991,7 +991,7 @@ def test_sample_neighbors_homo(labor, indptr_dtype, indices_dtype):
# 2. Sample with nodes in matched dtype with graph's indices. # 2. Sample with nodes in matched dtype with graph's indices.
nodes = torch.tensor([1, 3, 4], dtype=indices_dtype) nodes = torch.tensor([1, 3, 4], dtype=indices_dtype)
subgraph = sampler(nodes, fanouts) subgraph = sampler(nodes, fanouts, output_cscformat=False)
# Verify in subgraph. # Verify in subgraph.
sampled_num = subgraph.node_pairs[0].size(0) sampled_num = subgraph.node_pairs[0].size(0)
...@@ -1064,14 +1064,14 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype): ...@@ -1064,14 +1064,14 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
"Data type of nodes must be consistent with indices.dtype" "Data type of nodes must be consistent with indices.dtype"
), ),
): ):
_ = sampler(nodes, fanouts) _ = sampler(nodes, fanouts, output_cscformat=False)
# 2. Sample with nodes in matched dtype with graph's indices. # 2. Sample with nodes in matched dtype with graph's indices.
nodes = { nodes = {
"n1": torch.tensor([0], dtype=indices_dtype), "n1": torch.tensor([0], dtype=indices_dtype),
"n2": torch.tensor([0], dtype=indices_dtype), "n2": torch.tensor([0], dtype=indices_dtype),
} }
subgraph = sampler(nodes, fanouts) subgraph = sampler(nodes, fanouts, output_cscformat=False)
# Verify in subgraph. # Verify in subgraph.
expected_node_pairs = { expected_node_pairs = {
...@@ -1111,11 +1111,11 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype): ...@@ -1111,11 +1111,11 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
"Data type of nodes must be consistent with indices.dtype" "Data type of nodes must be consistent with indices.dtype"
), ),
): ):
_ = sampler(nodes, fanouts) _ = sampler(nodes, fanouts, output_cscformat=False)
# 2. Sample with nodes in matched dtype with graph's indices. # 2. Sample with nodes in matched dtype with graph's indices.
nodes = {"n1": torch.tensor([0], dtype=indices_dtype)} nodes = {"n1": torch.tensor([0], dtype=indices_dtype)}
subgraph = sampler(nodes, fanouts) subgraph = sampler(nodes, fanouts, output_cscformat=False)
# Verify in subgraph. # Verify in subgraph.
expected_node_pairs = { expected_node_pairs = {
...@@ -1194,7 +1194,7 @@ def test_sample_neighbors_fanouts( ...@@ -1194,7 +1194,7 @@ def test_sample_neighbors_fanouts(
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])} nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
fanouts = torch.LongTensor(fanouts) fanouts = torch.LongTensor(fanouts)
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts) subgraph = sampler(nodes, fanouts, output_cscformat=False)
# Verify in subgraph. # Verify in subgraph.
assert ( assert (
...@@ -1250,7 +1250,7 @@ def test_sample_neighbors_replace( ...@@ -1250,7 +1250,7 @@ def test_sample_neighbors_replace(
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])} nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
subgraph = graph.sample_neighbors( subgraph = graph.sample_neighbors(
nodes, torch.LongTensor([4]), replace=replace nodes, torch.LongTensor([4]), replace=replace, output_cscformat=False
) )
# Verify in subgraph. # Verify in subgraph.
...@@ -1288,7 +1288,9 @@ def test_sample_neighbors_return_eids_homo(labor): ...@@ -1288,7 +1288,9 @@ def test_sample_neighbors_return_eids_homo(labor):
# Generate subgraph via sample neighbors. # Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4]) nodes = torch.LongTensor([1, 3, 4])
subgraph = graph.sample_neighbors(nodes, fanouts=torch.LongTensor([-1])) subgraph = graph.sample_neighbors(
nodes, fanouts=torch.LongTensor([-1]), output_cscformat=False
)
# Verify in subgraph. # Verify in subgraph.
expected_reverse_edge_ids = edge_attributes[gb.ORIGINAL_EDGE_ID][ expected_reverse_edge_ids = edge_attributes[gb.ORIGINAL_EDGE_ID][
...@@ -1344,7 +1346,7 @@ def test_sample_neighbors_return_eids_hetero(labor): ...@@ -1344,7 +1346,7 @@ def test_sample_neighbors_return_eids_hetero(labor):
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])} nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
fanouts = torch.tensor([-1, -1]) fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts) subgraph = sampler(nodes, fanouts, output_cscformat=False)
# Verify in subgraph. # Verify in subgraph.
expected_reverse_edge_ids = { expected_reverse_edge_ids = {
...@@ -1402,6 +1404,7 @@ def test_sample_neighbors_probs(replace, labor, probs_name): ...@@ -1402,6 +1404,7 @@ def test_sample_neighbors_probs(replace, labor, probs_name):
fanouts=torch.tensor([2]), fanouts=torch.tensor([2]),
replace=replace, replace=replace,
probs_name=probs_name, probs_name=probs_name,
output_cscformat=False,
) )
# Verify in subgraph. # Verify in subgraph.
...@@ -1448,6 +1451,7 @@ def test_sample_neighbors_zero_probs(replace, labor, probs_or_mask): ...@@ -1448,6 +1451,7 @@ def test_sample_neighbors_zero_probs(replace, labor, probs_or_mask):
fanouts=torch.tensor([5]), fanouts=torch.tensor([5]),
replace=replace, replace=replace,
probs_name="probs_or_mask", probs_name="probs_or_mask",
output_cscformat=False,
) )
# Verify in subgraph. # Verify in subgraph.
...@@ -1959,6 +1963,7 @@ def test_sample_neighbors_homo_pick_number(fanouts, replace, labor, probs_name): ...@@ -1959,6 +1963,7 @@ def test_sample_neighbors_homo_pick_number(fanouts, replace, labor, probs_name):
fanouts=torch.LongTensor(fanouts), fanouts=torch.LongTensor(fanouts),
replace=replace, replace=replace,
probs_name=probs_name if probs_name != "none" else None, probs_name=probs_name if probs_name != "none" else None,
output_cscformat=False,
) )
sampled_num = subgraph.node_pairs[0].size(0) sampled_num = subgraph.node_pairs[0].size(0)
...@@ -2050,6 +2055,7 @@ def test_sample_neighbors_hetero_pick_number( ...@@ -2050,6 +2055,7 @@ def test_sample_neighbors_hetero_pick_number(
fanouts=torch.LongTensor(fanouts), fanouts=torch.LongTensor(fanouts),
replace=replace, replace=replace,
probs_name=probs_name if probs_name != "none" else None, probs_name=probs_name if probs_name != "none" else None,
output_cscformat=False,
) )
if probs_name == "none": if probs_name == "none":
......
...@@ -1143,6 +1143,7 @@ def test_OnDiskDataset_preprocess_homogeneous(): ...@@ -1143,6 +1143,7 @@ def test_OnDiskDataset_preprocess_homogeneous():
subgraph = fused_csc_sampling_graph.sample_neighbors( subgraph = fused_csc_sampling_graph.sample_neighbors(
torch.arange(num_samples), torch.arange(num_samples),
torch.tensor([fanout]), torch.tensor([fanout]),
output_cscformat=False,
) )
assert len(subgraph.node_pairs[0]) <= num_samples assert len(subgraph.node_pairs[0]) <= num_samples
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment