Unverified commit 869bfb67, authored by yxy235 and committed by GitHub
Browse files

[GraphBolt] Change default sampled subgraph output from coo to csc. (#6819)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 0bf51a16
......@@ -60,7 +60,7 @@ class InSubgraphSampler(SubgraphSampler):
datapipe,
graph,
# TODO: clean up once the migration is done.
output_cscformat=False,
output_cscformat=True,
):
super().__init__(datapipe)
self.graph = graph
......
......@@ -100,7 +100,7 @@ class NeighborSampler(SubgraphSampler):
prob_name=None,
deduplicate=True,
# TODO: clean up once the migration is done.
output_cscformat=False,
output_cscformat=True,
):
super().__init__(datapipe)
self.graph = graph
......@@ -270,7 +270,7 @@ class LayerNeighborSampler(NeighborSampler):
prob_name=None,
deduplicate=True,
# TODO: clean up once the migration is done.
output_cscformat=False,
output_cscformat=True,
):
super().__init__(
datapipe,
......
......@@ -80,7 +80,9 @@ def test_InSubgraphSampler_node_pairs_homo():
batch_size = 1
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
in_subgraph_sampler = gb.InSubgraphSampler(item_sampler, graph)
in_subgraph_sampler = gb.InSubgraphSampler(
item_sampler, graph, output_cscformat=False
)
it = iter(in_subgraph_sampler)
......@@ -156,7 +158,9 @@ def test_InSubgraphSampler_node_pairs_hetero():
batch_size = 2
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
in_subgraph_sampler = gb.InSubgraphSampler(item_sampler, graph)
in_subgraph_sampler = gb.InSubgraphSampler(
item_sampler, graph, output_cscformat=False
)
it = iter(in_subgraph_sampler)
......
......@@ -61,15 +61,19 @@ def test_integration_link_prediction():
expected = [
str(
"""MiniBatch(seed_nodes=None,
sampled_subgraphs=[FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]),
sampled_subgraphs=[SampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2, 0, 4]),
node_pairs=(tensor([5, 4]), tensor([0, 5])),
node_pairs=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2]),
indices=tensor([5, 4]),
),
FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]),
),
SampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2, 0]),
node_pairs=(tensor([5]), tensor([0])),
node_pairs=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1]),
indices=tensor([5]),
),
)],
positive_node_pairs=(tensor([0, 1, 1, 1]),
tensor([2, 3, 3, 1])),
......@@ -113,15 +117,19 @@ def test_integration_link_prediction():
),
str(
"""MiniBatch(seed_nodes=None,
sampled_subgraphs=[FusedSampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0, 5, 1]),
sampled_subgraphs=[SampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0, 5, 1]),
original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0, 5, 1]),
node_pairs=(tensor([1, 3]), tensor([3, 4])),
node_pairs=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2]),
indices=tensor([1, 3]),
),
),
FusedSampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0, 5, 1]),
SampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0, 5, 1]),
original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0, 5, 1]),
node_pairs=(tensor([1, 3]), tensor([3, 4])),
node_pairs=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2]),
indices=tensor([1, 3]),
),
)],
positive_node_pairs=(tensor([0, 1, 1, 2]),
tensor([0, 0, 1, 1])),
......@@ -164,15 +172,19 @@ def test_integration_link_prediction():
),
str(
"""MiniBatch(seed_nodes=None,
sampled_subgraphs=[FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 4]),
sampled_subgraphs=[SampledSubgraphImpl(original_row_node_ids=tensor([5, 4]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 4]),
node_pairs=(tensor([1]), tensor([1])),
node_pairs=CSCFormatBase(indptr=tensor([0, 0, 1]),
indices=tensor([1]),
),
),
FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 4]),
SampledSubgraphImpl(original_row_node_ids=tensor([5, 4]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 4]),
node_pairs=(tensor([1]), tensor([1])),
node_pairs=CSCFormatBase(indptr=tensor([0, 0, 1]),
indices=tensor([1]),
),
)],
positive_node_pairs=(tensor([0, 1]),
tensor([0, 0])),
......@@ -262,15 +274,19 @@ def test_integration_node_classification():
expected = [
str(
"""MiniBatch(seed_nodes=None,
sampled_subgraphs=[FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2, 4]),
sampled_subgraphs=[SampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2, 4]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2]),
node_pairs=(tensor([4, 1, 0, 1]), tensor([0, 1, 2, 3])),
node_pairs=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4]),
indices=tensor([4, 1, 0, 1]),
),
FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2]),
),
SampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2]),
node_pairs=(tensor([0, 1, 0, 1]), tensor([0, 1, 2, 3])),
node_pairs=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4]),
indices=tensor([0, 1, 0, 1]),
),
)],
positive_node_pairs=(tensor([0, 1, 1, 1]),
tensor([2, 3, 3, 1])),
......@@ -299,15 +315,19 @@ def test_integration_node_classification():
),
str(
"""MiniBatch(seed_nodes=None,
sampled_subgraphs=[FusedSampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0]),
sampled_subgraphs=[SampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0]),
original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0]),
node_pairs=(tensor([0, 2]), tensor([0, 1])),
node_pairs=CSCFormatBase(indptr=tensor([0, 1, 2, 2]),
indices=tensor([0, 2]),
),
),
FusedSampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0]),
SampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0]),
original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0]),
node_pairs=(tensor([0, 2]), tensor([0, 1])),
node_pairs=CSCFormatBase(indptr=tensor([0, 1, 2, 2]),
indices=tensor([0, 2]),
),
)],
positive_node_pairs=(tensor([0, 1, 1, 2]),
tensor([0, 0, 1, 1])),
......@@ -334,15 +354,19 @@ def test_integration_node_classification():
),
str(
"""MiniBatch(seed_nodes=None,
sampled_subgraphs=[FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 4, 0]),
sampled_subgraphs=[SampledSubgraphImpl(original_row_node_ids=tensor([5, 4, 0]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 4]),
node_pairs=(tensor([0, 2]), tensor([0, 1])),
node_pairs=CSCFormatBase(indptr=tensor([0, 1, 2]),
indices=tensor([0, 2]),
),
),
FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 4]),
SampledSubgraphImpl(original_row_node_ids=tensor([5, 4]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 4]),
node_pairs=(tensor([1, 1]), tensor([0, 1])),
node_pairs=CSCFormatBase(indptr=tensor([0, 1, 2]),
indices=tensor([1, 1]),
),
)],
positive_node_pairs=(tensor([0, 1]),
tensor([0, 0])),
......
......@@ -234,7 +234,7 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
itemset = gb.ItemSetDict(
{
"n2": gb.ItemSet(torch.tensor([0]), names="seed_nodes"),
"n1": gb.ItemSet(torch.tensor([1]), names="seed_nodes"),
"n1": gb.ItemSet(torch.tensor([0]), names="seed_nodes"),
}
)
......@@ -248,12 +248,12 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
for sampledsubgraph in data.sampled_subgraphs:
for _, value in sampledsubgraph.node_pairs.items():
assert torch.equal(
torch.ge(value[0], torch.zeros(len(value[0]))),
torch.ones(len(value[0])),
torch.ge(value.indices, torch.zeros(len(value.indices))),
torch.ones(len(value.indices)),
)
assert torch.equal(
torch.ge(value[1], torch.zeros(len(value[1]))),
torch.ones(len(value[1])),
torch.ge(value.indptr, torch.zeros(len(value.indptr))),
torch.ones(len(value.indptr)),
)
for _, value in sampledsubgraph.original_column_node_ids.items():
assert torch.equal(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment