"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "81d8d7b73f063434d71c83533dc7dcf23e4b5ecc"
Unverified Commit 3645e493 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Change default value from `coo` to `csc` in `FusedCSCSamplingGraph`. (#6833)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent ca72965a
......@@ -325,7 +325,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
self,
nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
# TODO: clean up once the migration is done.
output_cscformat=False,
output_cscformat=True,
) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
"""Return the subgraph induced on the inbound edges of the given nodes.
......@@ -555,7 +555,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
replace: bool = False,
probs_name: Optional[str] = None,
# TODO: clean up once the migration is done.
output_cscformat=False,
output_cscformat=True,
) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
"""Sample neighboring edges of the given nodes and return the induced
subgraph.
......@@ -745,7 +745,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
replace: bool = False,
probs_name: Optional[str] = None,
# TODO: clean up once the migration is done.
output_cscformat=False,
output_cscformat=True,
) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
"""Sample neighboring edges of the given nodes and return the induced
subgraph via layer-neighbor sampling from the NeurIPS 2023 paper
......
......@@ -706,7 +706,7 @@ def test_in_subgraph_node_pairs_homogeneous():
# Extract in subgraph.
nodes = torch.LongTensor([4, 1, 3])
in_subgraph = graph.in_subgraph(nodes)
in_subgraph = graph.in_subgraph(nodes, output_cscformat=False)
# Verify in subgraph.
assert torch.equal(
......@@ -778,7 +778,7 @@ def test_in_subgraph_node_pairs_heterogeneous():
"N0": torch.LongTensor([1]),
"N1": torch.LongTensor([2, 1]),
}
in_subgraph = graph.in_subgraph(nodes)
in_subgraph = graph.in_subgraph(nodes, output_cscformat=False)
# Verify in subgraph.
assert torch.equal(
......@@ -991,7 +991,7 @@ def test_sample_neighbors_homo(labor, indptr_dtype, indices_dtype):
# 2. Sample with nodes in matched dtype with graph's indices.
nodes = torch.tensor([1, 3, 4], dtype=indices_dtype)
subgraph = sampler(nodes, fanouts)
subgraph = sampler(nodes, fanouts, output_cscformat=False)
# Verify in subgraph.
sampled_num = subgraph.node_pairs[0].size(0)
......@@ -1064,14 +1064,14 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
"Data type of nodes must be consistent with indices.dtype"
),
):
_ = sampler(nodes, fanouts)
_ = sampler(nodes, fanouts, output_cscformat=False)
# 2. Sample with nodes in matched dtype with graph's indices.
nodes = {
"n1": torch.tensor([0], dtype=indices_dtype),
"n2": torch.tensor([0], dtype=indices_dtype),
}
subgraph = sampler(nodes, fanouts)
subgraph = sampler(nodes, fanouts, output_cscformat=False)
# Verify in subgraph.
expected_node_pairs = {
......@@ -1111,11 +1111,11 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
"Data type of nodes must be consistent with indices.dtype"
),
):
_ = sampler(nodes, fanouts)
_ = sampler(nodes, fanouts, output_cscformat=False)
# 2. Sample with nodes in matched dtype with graph's indices.
nodes = {"n1": torch.tensor([0], dtype=indices_dtype)}
subgraph = sampler(nodes, fanouts)
subgraph = sampler(nodes, fanouts, output_cscformat=False)
# Verify in subgraph.
expected_node_pairs = {
......@@ -1194,7 +1194,7 @@ def test_sample_neighbors_fanouts(
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
fanouts = torch.LongTensor(fanouts)
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts)
subgraph = sampler(nodes, fanouts, output_cscformat=False)
# Verify in subgraph.
assert (
......@@ -1250,7 +1250,7 @@ def test_sample_neighbors_replace(
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
subgraph = graph.sample_neighbors(
nodes, torch.LongTensor([4]), replace=replace
nodes, torch.LongTensor([4]), replace=replace, output_cscformat=False
)
# Verify in subgraph.
......@@ -1288,7 +1288,9 @@ def test_sample_neighbors_return_eids_homo(labor):
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
subgraph = graph.sample_neighbors(nodes, fanouts=torch.LongTensor([-1]))
subgraph = graph.sample_neighbors(
nodes, fanouts=torch.LongTensor([-1]), output_cscformat=False
)
# Verify in subgraph.
expected_reverse_edge_ids = edge_attributes[gb.ORIGINAL_EDGE_ID][
......@@ -1344,7 +1346,7 @@ def test_sample_neighbors_return_eids_hetero(labor):
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts)
subgraph = sampler(nodes, fanouts, output_cscformat=False)
# Verify in subgraph.
expected_reverse_edge_ids = {
......@@ -1402,6 +1404,7 @@ def test_sample_neighbors_probs(replace, labor, probs_name):
fanouts=torch.tensor([2]),
replace=replace,
probs_name=probs_name,
output_cscformat=False,
)
# Verify in subgraph.
......@@ -1448,6 +1451,7 @@ def test_sample_neighbors_zero_probs(replace, labor, probs_or_mask):
fanouts=torch.tensor([5]),
replace=replace,
probs_name="probs_or_mask",
output_cscformat=False,
)
# Verify in subgraph.
......@@ -1959,6 +1963,7 @@ def test_sample_neighbors_homo_pick_number(fanouts, replace, labor, probs_name):
fanouts=torch.LongTensor(fanouts),
replace=replace,
probs_name=probs_name if probs_name != "none" else None,
output_cscformat=False,
)
sampled_num = subgraph.node_pairs[0].size(0)
......@@ -2050,6 +2055,7 @@ def test_sample_neighbors_hetero_pick_number(
fanouts=torch.LongTensor(fanouts),
replace=replace,
probs_name=probs_name if probs_name != "none" else None,
output_cscformat=False,
)
if probs_name == "none":
......
......@@ -1143,6 +1143,7 @@ def test_OnDiskDataset_preprocess_homogeneous():
subgraph = fused_csc_sampling_graph.sample_neighbors(
torch.arange(num_samples),
torch.tensor([fanout]),
output_cscformat=False,
)
assert len(subgraph.node_pairs[0]) <= num_samples
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment