"graphbolt/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "ae97049e6ccd4cfaebdc911ecd79407fc2a5ffc6"
Unverified commit 869bfb67, authored by yxy235, committed by GitHub
Browse files

[GraphBolt] Change default sampled subgraph output from coo to csc. (#6819)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 0bf51a16
...@@ -60,7 +60,7 @@ class InSubgraphSampler(SubgraphSampler): ...@@ -60,7 +60,7 @@ class InSubgraphSampler(SubgraphSampler):
datapipe, datapipe,
graph, graph,
# TODO: clean up once the migration is done. # TODO: clean up once the migration is done.
output_cscformat=False, output_cscformat=True,
): ):
super().__init__(datapipe) super().__init__(datapipe)
self.graph = graph self.graph = graph
......
...@@ -100,7 +100,7 @@ class NeighborSampler(SubgraphSampler): ...@@ -100,7 +100,7 @@ class NeighborSampler(SubgraphSampler):
prob_name=None, prob_name=None,
deduplicate=True, deduplicate=True,
# TODO: clean up once the migration is done. # TODO: clean up once the migration is done.
output_cscformat=False, output_cscformat=True,
): ):
super().__init__(datapipe) super().__init__(datapipe)
self.graph = graph self.graph = graph
...@@ -270,7 +270,7 @@ class LayerNeighborSampler(NeighborSampler): ...@@ -270,7 +270,7 @@ class LayerNeighborSampler(NeighborSampler):
prob_name=None, prob_name=None,
deduplicate=True, deduplicate=True,
# TODO: clean up once the migration is done. # TODO: clean up once the migration is done.
output_cscformat=False, output_cscformat=True,
): ):
super().__init__( super().__init__(
datapipe, datapipe,
......
...@@ -80,7 +80,9 @@ def test_InSubgraphSampler_node_pairs_homo(): ...@@ -80,7 +80,9 @@ def test_InSubgraphSampler_node_pairs_homo():
batch_size = 1 batch_size = 1
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size) item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
in_subgraph_sampler = gb.InSubgraphSampler(item_sampler, graph) in_subgraph_sampler = gb.InSubgraphSampler(
item_sampler, graph, output_cscformat=False
)
it = iter(in_subgraph_sampler) it = iter(in_subgraph_sampler)
...@@ -156,7 +158,9 @@ def test_InSubgraphSampler_node_pairs_hetero(): ...@@ -156,7 +158,9 @@ def test_InSubgraphSampler_node_pairs_hetero():
batch_size = 2 batch_size = 2
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size) item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
in_subgraph_sampler = gb.InSubgraphSampler(item_sampler, graph) in_subgraph_sampler = gb.InSubgraphSampler(
item_sampler, graph, output_cscformat=False
)
it = iter(in_subgraph_sampler) it = iter(in_subgraph_sampler)
......
...@@ -61,15 +61,19 @@ def test_integration_link_prediction(): ...@@ -61,15 +61,19 @@ def test_integration_link_prediction():
expected = [ expected = [
str( str(
"""MiniBatch(seed_nodes=None, """MiniBatch(seed_nodes=None,
sampled_subgraphs=[FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]), sampled_subgraphs=[SampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2, 0, 4]), original_column_node_ids=tensor([5, 3, 1, 2, 0, 4]),
node_pairs=(tensor([5, 4]), tensor([0, 5])), node_pairs=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2]),
indices=tensor([5, 4]),
),
), ),
FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]), SampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2, 0]), original_column_node_ids=tensor([5, 3, 1, 2, 0]),
node_pairs=(tensor([5]), tensor([0])), node_pairs=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1]),
indices=tensor([5]),
),
)], )],
positive_node_pairs=(tensor([0, 1, 1, 1]), positive_node_pairs=(tensor([0, 1, 1, 1]),
tensor([2, 3, 3, 1])), tensor([2, 3, 3, 1])),
...@@ -113,15 +117,19 @@ def test_integration_link_prediction(): ...@@ -113,15 +117,19 @@ def test_integration_link_prediction():
), ),
str( str(
"""MiniBatch(seed_nodes=None, """MiniBatch(seed_nodes=None,
sampled_subgraphs=[FusedSampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0, 5, 1]), sampled_subgraphs=[SampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0, 5, 1]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0, 5, 1]), original_column_node_ids=tensor([3, 4, 0, 5, 1]),
node_pairs=(tensor([1, 3]), tensor([3, 4])), node_pairs=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2]),
indices=tensor([1, 3]),
),
), ),
FusedSampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0, 5, 1]), SampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0, 5, 1]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0, 5, 1]), original_column_node_ids=tensor([3, 4, 0, 5, 1]),
node_pairs=(tensor([1, 3]), tensor([3, 4])), node_pairs=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2]),
indices=tensor([1, 3]),
),
)], )],
positive_node_pairs=(tensor([0, 1, 1, 2]), positive_node_pairs=(tensor([0, 1, 1, 2]),
tensor([0, 0, 1, 1])), tensor([0, 0, 1, 1])),
...@@ -164,15 +172,19 @@ def test_integration_link_prediction(): ...@@ -164,15 +172,19 @@ def test_integration_link_prediction():
), ),
str( str(
"""MiniBatch(seed_nodes=None, """MiniBatch(seed_nodes=None,
sampled_subgraphs=[FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 4]), sampled_subgraphs=[SampledSubgraphImpl(original_row_node_ids=tensor([5, 4]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 4]), original_column_node_ids=tensor([5, 4]),
node_pairs=(tensor([1]), tensor([1])), node_pairs=CSCFormatBase(indptr=tensor([0, 0, 1]),
indices=tensor([1]),
),
), ),
FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 4]), SampledSubgraphImpl(original_row_node_ids=tensor([5, 4]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 4]), original_column_node_ids=tensor([5, 4]),
node_pairs=(tensor([1]), tensor([1])), node_pairs=CSCFormatBase(indptr=tensor([0, 0, 1]),
indices=tensor([1]),
),
)], )],
positive_node_pairs=(tensor([0, 1]), positive_node_pairs=(tensor([0, 1]),
tensor([0, 0])), tensor([0, 0])),
...@@ -262,15 +274,19 @@ def test_integration_node_classification(): ...@@ -262,15 +274,19 @@ def test_integration_node_classification():
expected = [ expected = [
str( str(
"""MiniBatch(seed_nodes=None, """MiniBatch(seed_nodes=None,
sampled_subgraphs=[FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2, 4]), sampled_subgraphs=[SampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2, 4]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2]), original_column_node_ids=tensor([5, 3, 1, 2]),
node_pairs=(tensor([4, 1, 0, 1]), tensor([0, 1, 2, 3])), node_pairs=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4]),
indices=tensor([4, 1, 0, 1]),
),
), ),
FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2]), SampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2]), original_column_node_ids=tensor([5, 3, 1, 2]),
node_pairs=(tensor([0, 1, 0, 1]), tensor([0, 1, 2, 3])), node_pairs=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4]),
indices=tensor([0, 1, 0, 1]),
),
)], )],
positive_node_pairs=(tensor([0, 1, 1, 1]), positive_node_pairs=(tensor([0, 1, 1, 1]),
tensor([2, 3, 3, 1])), tensor([2, 3, 3, 1])),
...@@ -299,15 +315,19 @@ def test_integration_node_classification(): ...@@ -299,15 +315,19 @@ def test_integration_node_classification():
), ),
str( str(
"""MiniBatch(seed_nodes=None, """MiniBatch(seed_nodes=None,
sampled_subgraphs=[FusedSampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0]), sampled_subgraphs=[SampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0]), original_column_node_ids=tensor([3, 4, 0]),
node_pairs=(tensor([0, 2]), tensor([0, 1])), node_pairs=CSCFormatBase(indptr=tensor([0, 1, 2, 2]),
indices=tensor([0, 2]),
),
), ),
FusedSampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0]), SampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0]), original_column_node_ids=tensor([3, 4, 0]),
node_pairs=(tensor([0, 2]), tensor([0, 1])), node_pairs=CSCFormatBase(indptr=tensor([0, 1, 2, 2]),
indices=tensor([0, 2]),
),
)], )],
positive_node_pairs=(tensor([0, 1, 1, 2]), positive_node_pairs=(tensor([0, 1, 1, 2]),
tensor([0, 0, 1, 1])), tensor([0, 0, 1, 1])),
...@@ -334,15 +354,19 @@ def test_integration_node_classification(): ...@@ -334,15 +354,19 @@ def test_integration_node_classification():
), ),
str( str(
"""MiniBatch(seed_nodes=None, """MiniBatch(seed_nodes=None,
sampled_subgraphs=[FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 4, 0]), sampled_subgraphs=[SampledSubgraphImpl(original_row_node_ids=tensor([5, 4, 0]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 4]), original_column_node_ids=tensor([5, 4]),
node_pairs=(tensor([0, 2]), tensor([0, 1])), node_pairs=CSCFormatBase(indptr=tensor([0, 1, 2]),
indices=tensor([0, 2]),
),
), ),
FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 4]), SampledSubgraphImpl(original_row_node_ids=tensor([5, 4]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 4]), original_column_node_ids=tensor([5, 4]),
node_pairs=(tensor([1, 1]), tensor([0, 1])), node_pairs=CSCFormatBase(indptr=tensor([0, 1, 2]),
indices=tensor([1, 1]),
),
)], )],
positive_node_pairs=(tensor([0, 1]), positive_node_pairs=(tensor([0, 1]),
tensor([0, 0])), tensor([0, 0])),
......
...@@ -234,7 +234,7 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor): ...@@ -234,7 +234,7 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
itemset = gb.ItemSetDict( itemset = gb.ItemSetDict(
{ {
"n2": gb.ItemSet(torch.tensor([0]), names="seed_nodes"), "n2": gb.ItemSet(torch.tensor([0]), names="seed_nodes"),
"n1": gb.ItemSet(torch.tensor([1]), names="seed_nodes"), "n1": gb.ItemSet(torch.tensor([0]), names="seed_nodes"),
} }
) )
...@@ -248,12 +248,12 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor): ...@@ -248,12 +248,12 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
for sampledsubgraph in data.sampled_subgraphs: for sampledsubgraph in data.sampled_subgraphs:
for _, value in sampledsubgraph.node_pairs.items(): for _, value in sampledsubgraph.node_pairs.items():
assert torch.equal( assert torch.equal(
torch.ge(value[0], torch.zeros(len(value[0]))), torch.ge(value.indices, torch.zeros(len(value.indices))),
torch.ones(len(value[0])), torch.ones(len(value.indices)),
) )
assert torch.equal( assert torch.equal(
torch.ge(value[1], torch.zeros(len(value[1]))), torch.ge(value.indptr, torch.zeros(len(value.indptr))),
torch.ones(len(value[1])), torch.ones(len(value.indptr)),
) )
for _, value in sampledsubgraph.original_column_node_ids.items(): for _, value in sampledsubgraph.original_column_node_ids.items():
assert torch.equal( assert torch.equal(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment