"git@developer.sourcefind.cn:OpenDAS/bitsandbytes.git" did not exist on "7facedda38da928843e9ed0de1810d45ce1b9224"
Unverified commit 13e7c2fa, authored by Andrei Ivanov and committed by GitHub
Browse files

[GraphBolt] Improving `subgraph_sampler` tests. (#7047)

parent 8204fe19
import unittest import unittest
import warnings
from enum import Enum from enum import Enum
from functools import partial from functools import partial
...@@ -9,7 +10,6 @@ import dgl ...@@ -9,7 +10,6 @@ import dgl
import dgl.graphbolt as gb import dgl.graphbolt as gb
import pytest import pytest
import torch import torch
from torchdata.datapipes.iter import Mapper
from . import gb_test_utils from . import gb_test_utils
...@@ -22,6 +22,12 @@ def _check_sampler_type(sampler_type): ...@@ -22,6 +22,12 @@ def _check_sampler_type(sampler_type):
) )
def _check_sampler_len(sampler, lenExp):
    """Assert that iterating ``sampler`` yields exactly ``lenExp`` items.

    ``UserWarning``s emitted while the sampler is being consumed are
    silenced for the duration of the check only; the previously active
    warning filters are restored when the check finishes.

    Parameters
    ----------
    sampler : iterable
        The sampler / datapipe to exhaust.
    lenExp : int
        Expected number of items produced by ``sampler``.

    Raises
    ------
    AssertionError
        If the number of produced items differs from ``lenExp``.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        # Count lazily so the produced items are not all kept alive at once.
        actual = sum(1 for _ in sampler)
    assert actual == lenExp, f"expected {lenExp} items, got {actual}"
class SamplerType(Enum): class SamplerType(Enum):
Normal = 0 Normal = 0
Layer = 1 Layer = 1
...@@ -128,7 +134,7 @@ def test_SubgraphSampler_Node_seed_nodes(sampler_type): ...@@ -128,7 +134,7 @@ def test_SubgraphSampler_Node_seed_nodes(sampler_type):
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
sampler = _get_sampler(sampler_type) sampler = _get_sampler(sampler_type)
sampler_dp = sampler(item_sampler, graph, fanouts) sampler_dp = sampler(item_sampler, graph, fanouts)
assert len(list(sampler_dp)) == 5 _check_sampler_len(sampler_dp, 5)
def to_link_batch(data): def to_link_batch(data):
...@@ -161,7 +167,7 @@ def test_SubgraphSampler_Link_node_pairs(sampler_type): ...@@ -161,7 +167,7 @@ def test_SubgraphSampler_Link_node_pairs(sampler_type):
sampler = _get_sampler(sampler_type) sampler = _get_sampler(sampler_type)
datapipe = sampler(datapipe, graph, fanouts) datapipe = sampler(datapipe, graph, fanouts)
datapipe = datapipe.transform(partial(gb.exclude_seed_edges)) datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5 _check_sampler_len(datapipe, 5)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -190,7 +196,7 @@ def test_SubgraphSampler_Link_With_Negative_node_pairs(sampler_type): ...@@ -190,7 +196,7 @@ def test_SubgraphSampler_Link_With_Negative_node_pairs(sampler_type):
sampler = _get_sampler(sampler_type) sampler = _get_sampler(sampler_type)
datapipe = sampler(datapipe, graph, fanouts) datapipe = sampler(datapipe, graph, fanouts)
datapipe = datapipe.transform(partial(gb.exclude_seed_edges)) datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5 _check_sampler_len(datapipe, 5)
def get_hetero_graph(): def get_hetero_graph():
...@@ -239,9 +245,11 @@ def test_SubgraphSampler_Node_seed_nodes_Hetero(sampler_type): ...@@ -239,9 +245,11 @@ def test_SubgraphSampler_Node_seed_nodes_Hetero(sampler_type):
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
sampler = _get_sampler(sampler_type) sampler = _get_sampler(sampler_type)
sampler_dp = sampler(item_sampler, graph, fanouts) sampler_dp = sampler(item_sampler, graph, fanouts)
assert len(list(sampler_dp)) == 2 _check_sampler_len(sampler_dp, 2)
for minibatch in sampler_dp: with warnings.catch_warnings():
assert len(minibatch.sampled_subgraphs) == num_layer warnings.simplefilter("ignore", category=UserWarning)
for minibatch in sampler_dp:
assert len(minibatch.sampled_subgraphs) == num_layer
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -285,7 +293,7 @@ def test_SubgraphSampler_Link_Hetero_node_pairs(sampler_type): ...@@ -285,7 +293,7 @@ def test_SubgraphSampler_Link_Hetero_node_pairs(sampler_type):
sampler = _get_sampler(sampler_type) sampler = _get_sampler(sampler_type)
datapipe = sampler(datapipe, graph, fanouts) datapipe = sampler(datapipe, graph, fanouts)
datapipe = datapipe.transform(partial(gb.exclude_seed_edges)) datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5 _check_sampler_len(datapipe, 5)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -330,7 +338,7 @@ def test_SubgraphSampler_Link_Hetero_With_Negative_node_pairs(sampler_type): ...@@ -330,7 +338,7 @@ def test_SubgraphSampler_Link_Hetero_With_Negative_node_pairs(sampler_type):
sampler = _get_sampler(sampler_type) sampler = _get_sampler(sampler_type)
datapipe = sampler(datapipe, graph, fanouts) datapipe = sampler(datapipe, graph, fanouts)
datapipe = datapipe.transform(partial(gb.exclude_seed_edges)) datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5 _check_sampler_len(datapipe, 5)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -375,7 +383,7 @@ def test_SubgraphSampler_Link_Hetero_Unknown_Etype_node_pairs(sampler_type): ...@@ -375,7 +383,7 @@ def test_SubgraphSampler_Link_Hetero_Unknown_Etype_node_pairs(sampler_type):
sampler = _get_sampler(sampler_type) sampler = _get_sampler(sampler_type)
datapipe = sampler(datapipe, graph, fanouts) datapipe = sampler(datapipe, graph, fanouts)
datapipe = datapipe.transform(partial(gb.exclude_seed_edges)) datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5 _check_sampler_len(datapipe, 5)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -423,7 +431,7 @@ def test_SubgraphSampler_Link_Hetero_With_Negative_Unknown_Etype_node_pairs( ...@@ -423,7 +431,7 @@ def test_SubgraphSampler_Link_Hetero_With_Negative_Unknown_Etype_node_pairs(
sampler = _get_sampler(sampler_type) sampler = _get_sampler(sampler_type)
datapipe = sampler(datapipe, graph, fanouts) datapipe = sampler(datapipe, graph, fanouts)
datapipe = datapipe.transform(partial(gb.exclude_seed_edges)) datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5 _check_sampler_len(datapipe, 5)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -493,32 +501,28 @@ def test_SubgraphSampler_Random_Hetero_Graph_seed_ndoes(sampler_type, replace): ...@@ -493,32 +501,28 @@ def test_SubgraphSampler_Random_Hetero_Graph_seed_ndoes(sampler_type, replace):
sampler_dp = sampler(item_sampler, graph, fanouts, replace=replace) sampler_dp = sampler(item_sampler, graph, fanouts, replace=replace)
for data in sampler_dp: with warnings.catch_warnings():
for sampledsubgraph in data.sampled_subgraphs: warnings.simplefilter("ignore", category=UserWarning)
for _, value in sampledsubgraph.sampled_csc.items(): for data in sampler_dp:
assert torch.equal( for sampledsubgraph in data.sampled_subgraphs:
torch.ge( for _, value in sampledsubgraph.sampled_csc.items():
value.indices, for idx in [value.indices, value.indptr]:
torch.zeros(len(value.indices)).to(F.ctx()), assert torch.equal(
), torch.ge(idx, torch.zeros(len(idx)).to(F.ctx())),
torch.ones(len(value.indices)).to(F.ctx()), torch.ones(len(idx)).to(F.ctx()),
) )
assert torch.equal( node_ids = [
torch.ge( sampledsubgraph.original_column_node_ids,
value.indptr, torch.zeros(len(value.indptr)).to(F.ctx()) sampledsubgraph.original_row_node_ids,
), ]
torch.ones(len(value.indptr)).to(F.ctx()), for ids in node_ids:
) for _, value in ids.items():
for _, value in sampledsubgraph.original_column_node_ids.items(): assert torch.equal(
assert torch.equal( torch.ge(
torch.ge(value, torch.zeros(len(value)).to(F.ctx())), value, torch.zeros(len(value)).to(F.ctx())
torch.ones(len(value)).to(F.ctx()), ),
) torch.ones(len(value)).to(F.ctx()),
for _, value in sampledsubgraph.original_row_node_ids.items(): )
assert torch.equal(
torch.ge(value, torch.zeros(len(value)).to(F.ctx())),
torch.ones(len(value)).to(F.ctx()),
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -570,9 +574,60 @@ def test_SubgraphSampler_without_dedpulication_Homo_seed_nodes(sampler_type): ...@@ -570,9 +574,60 @@ def test_SubgraphSampler_without_dedpulication_Homo_seed_nodes(sampler_type):
torch.tensor([0, 2, 2, 3, 4, 4, 5]).to(F.ctx()), torch.tensor([0, 2, 2, 3, 4, 4, 5]).to(F.ctx()),
torch.tensor([0, 3, 4]).to(F.ctx()), torch.tensor([0, 3, 4]).to(F.ctx()),
] ]
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
for data in datapipe:
for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
assert (
len(sampled_subgraph.original_row_node_ids) == length[step]
)
assert torch.equal(
sampled_subgraph.sampled_csc.indices,
compacted_indices[step],
)
assert torch.equal(
sampled_subgraph.sampled_csc.indptr, indptr[step]
)
assert torch.equal(
torch.sort(sampled_subgraph.original_column_node_ids)[0],
seeds[step],
)
def _assert_hetero_values(
    datapipe, original_row_node_ids, original_column_node_ids, csc_formats
):
    """Compare every sampled heterogeneous subgraph with its expected values.

    For each minibatch yielded by ``datapipe`` and each position ``step`` in
    its ``sampled_subgraphs``, the following are compared with
    ``torch.equal``:

    * ``original_row_node_ids`` and ``original_column_node_ids`` for the
      node types ``"n1"`` and ``"n2"``;
    * ``sampled_csc[etype].indices`` and ``sampled_csc[etype].indptr`` for
      the edge types ``"n1:e1:n2"`` and ``"n2:e2:n1"``.

    Expected tensors are moved to ``F.ctx()`` before the comparison.

    Parameters
    ----------
    datapipe : iterable
        Yields minibatches exposing ``sampled_subgraphs``.
    original_row_node_ids : sequence
        Indexed by ``step``; each entry maps node type to the expected tensor.
    original_column_node_ids : sequence
        Indexed by ``step``; each entry maps node type to the expected tensor.
    csc_formats : sequence
        Indexed by ``step``; each entry maps edge type to an object with
        ``indices`` and ``indptr`` tensors.

    Raises
    ------
    AssertionError
        On the first mismatch; the message names the field, step and type.
    """
    for data in datapipe:
        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
            for ntype in ("n1", "n2"):
                assert torch.equal(
                    sampled_subgraph.original_row_node_ids[ntype],
                    original_row_node_ids[step][ntype].to(F.ctx()),
                ), f"original_row_node_ids mismatch: step={step}, ntype={ntype}"
                assert torch.equal(
                    sampled_subgraph.original_column_node_ids[ntype],
                    original_column_node_ids[step][ntype].to(F.ctx()),
                ), (
                    "original_column_node_ids mismatch: "
                    f"step={step}, ntype={ntype}"
                )
            for etype in ("n1:e1:n2", "n2:e2:n1"):
                assert torch.equal(
                    sampled_subgraph.sampled_csc[etype].indices,
                    csc_formats[step][etype].indices.to(F.ctx()),
                ), f"sampled_csc indices mismatch: step={step}, etype={etype}"
                assert torch.equal(
                    sampled_subgraph.sampled_csc[etype].indptr,
                    csc_formats[step][etype].indptr.to(F.ctx()),
                ), f"sampled_csc indptr mismatch: step={step}, etype={etype}"
def _assert_homo_values(
datapipe, original_row_node_ids, compacted_indices, indptr, seeds
):
for data in datapipe:
for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
assert torch.equal(
sampled_subgraph.original_row_node_ids,
original_row_node_ids[step],
)
assert torch.equal( assert torch.equal(
sampled_subgraph.sampled_csc.indices, compacted_indices[step] sampled_subgraph.sampled_csc.indices, compacted_indices[step]
) )
...@@ -580,8 +635,7 @@ def test_SubgraphSampler_without_dedpulication_Homo_seed_nodes(sampler_type): ...@@ -580,8 +635,7 @@ def test_SubgraphSampler_without_dedpulication_Homo_seed_nodes(sampler_type):
sampled_subgraph.sampled_csc.indptr, indptr[step] sampled_subgraph.sampled_csc.indptr, indptr[step]
) )
assert torch.equal( assert torch.equal(
torch.sort(sampled_subgraph.original_column_node_ids)[0], sampled_subgraph.original_column_node_ids, seeds[step]
seeds[step],
) )
...@@ -655,26 +709,14 @@ def test_SubgraphSampler_without_dedpulication_Hetero_seed_nodes(sampler_type): ...@@ -655,26 +709,14 @@ def test_SubgraphSampler_without_dedpulication_Hetero_seed_nodes(sampler_type):
}, },
] ]
for data in datapipe: with warnings.catch_warnings():
for step, sampled_subgraph in enumerate(data.sampled_subgraphs): warnings.simplefilter("ignore", category=UserWarning)
for ntype in ["n1", "n2"]: _assert_hetero_values(
assert torch.equal( datapipe,
sampled_subgraph.original_row_node_ids[ntype], original_row_node_ids,
original_row_node_ids[step][ntype].to(F.ctx()), original_column_node_ids,
) csc_formats,
assert torch.equal( )
sampled_subgraph.original_column_node_ids[ntype],
original_column_node_ids[step][ntype].to(F.ctx()),
)
for etype in ["n1:e1:n2", "n2:e2:n1"]:
assert torch.equal(
sampled_subgraph.sampled_csc[etype].indices,
csc_formats[step][etype].indices.to(F.ctx()),
)
assert torch.equal(
sampled_subgraph.sampled_csc[etype].indptr,
csc_formats[step][etype].indptr.to(F.ctx()),
)
@unittest.skipIf( @unittest.skipIf(
...@@ -719,21 +761,9 @@ def test_SubgraphSampler_unique_csc_format_Homo_cpu_seed_nodes(labor): ...@@ -719,21 +761,9 @@ def test_SubgraphSampler_unique_csc_format_Homo_cpu_seed_nodes(labor):
torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()), torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
torch.tensor([0, 3, 4]).to(F.ctx()), torch.tensor([0, 3, 4]).to(F.ctx()),
] ]
for data in datapipe: _assert_homo_values(
for step, sampled_subgraph in enumerate(data.sampled_subgraphs): datapipe, original_row_node_ids, compacted_indices, indptr, seeds
assert torch.equal( )
sampled_subgraph.original_row_node_ids,
original_row_node_ids[step],
)
assert torch.equal(
sampled_subgraph.sampled_csc.indices, compacted_indices[step]
)
assert torch.equal(
sampled_subgraph.sampled_csc.indptr, indptr[step]
)
assert torch.equal(
sampled_subgraph.original_column_node_ids, seeds[step]
)
@unittest.skipIf( @unittest.skipIf(
...@@ -778,21 +808,9 @@ def test_SubgraphSampler_unique_csc_format_Homo_gpu_seed_nodes(labor): ...@@ -778,21 +808,9 @@ def test_SubgraphSampler_unique_csc_format_Homo_gpu_seed_nodes(labor):
torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()), torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
torch.tensor([0, 3, 4]).to(F.ctx()), torch.tensor([0, 3, 4]).to(F.ctx()),
] ]
for data in datapipe: _assert_homo_values(
for step, sampled_subgraph in enumerate(data.sampled_subgraphs): datapipe, original_row_node_ids, compacted_indices, indptr, seeds
assert torch.equal( )
sampled_subgraph.original_row_node_ids,
original_row_node_ids[step],
)
assert torch.equal(
sampled_subgraph.sampled_csc.indices, compacted_indices[step]
)
assert torch.equal(
sampled_subgraph.sampled_csc.indptr, indptr[step]
)
assert torch.equal(
sampled_subgraph.original_column_node_ids, seeds[step]
)
@pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize("labor", [False, True])
...@@ -853,27 +871,9 @@ def test_SubgraphSampler_unique_csc_format_Hetero_seed_nodes(labor): ...@@ -853,27 +871,9 @@ def test_SubgraphSampler_unique_csc_format_Hetero_seed_nodes(labor):
"n2": torch.tensor([0, 1]), "n2": torch.tensor([0, 1]),
}, },
] ]
_assert_hetero_values(
for data in datapipe: datapipe, original_row_node_ids, original_column_node_ids, csc_formats
for step, sampled_subgraph in enumerate(data.sampled_subgraphs): )
for ntype in ["n1", "n2"]:
assert torch.equal(
sampled_subgraph.original_row_node_ids[ntype],
original_row_node_ids[step][ntype].to(F.ctx()),
)
assert torch.equal(
sampled_subgraph.original_column_node_ids[ntype],
original_column_node_ids[step][ntype].to(F.ctx()),
)
for etype in ["n1:e1:n2", "n2:e2:n1"]:
assert torch.equal(
sampled_subgraph.sampled_csc[etype].indices,
csc_formats[step][etype].indices.to(F.ctx()),
)
assert torch.equal(
sampled_subgraph.sampled_csc[etype].indptr,
csc_formats[step][etype].indptr.to(F.ctx()),
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -886,7 +886,9 @@ def test_SubgraphSampler_Hetero_multifanout_per_layer_seed_nodes(sampler_type): ...@@ -886,7 +886,9 @@ def test_SubgraphSampler_Hetero_multifanout_per_layer_seed_nodes(sampler_type):
items_n1 = torch.tensor([0]) items_n1 = torch.tensor([0])
items_n2 = torch.tensor([1]) items_n2 = torch.tensor([1])
names = "seed_nodes" names = "seed_nodes"
item_length = 2
if sampler_type == SamplerType.Temporal: if sampler_type == SamplerType.Temporal:
item_length = 3
graph.node_attributes = { graph.node_attributes = {
"timestamp": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx()) "timestamp": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx())
} }
...@@ -909,38 +911,31 @@ def test_SubgraphSampler_Hetero_multifanout_per_layer_seed_nodes(sampler_type): ...@@ -909,38 +911,31 @@ def test_SubgraphSampler_Hetero_multifanout_per_layer_seed_nodes(sampler_type):
fanouts = [torch.LongTensor([2, 1]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2, 1]) for _ in range(num_layer)]
sampler = _get_sampler(sampler_type) sampler = _get_sampler(sampler_type)
sampler_dp = sampler(item_sampler, graph, fanouts) sampler_dp = sampler(item_sampler, graph, fanouts)
if sampler_type == SamplerType.Temporal: indices_len = [
indices_len = [ {
{ "n1:e1:n2": 4,
"n1:e1:n2": 4, "n2:e2:n1": item_length,
"n2:e2:n1": 3, },
}, {
{ "n1:e1:n2": 2,
"n1:e1:n2": 2, "n2:e2:n1": 1,
"n2:e2:n1": 1, },
}, ]
]
else: with warnings.catch_warnings():
indices_len = [ warnings.simplefilter("ignore", category=UserWarning)
{ for minibatch in sampler_dp:
"n1:e1:n2": 4, for step, sampled_subgraph in enumerate(
"n2:e2:n1": 2, minibatch.sampled_subgraphs
}, ):
{ assert (
"n1:e1:n2": 2, len(sampled_subgraph.sampled_csc["n1:e1:n2"].indices)
"n2:e2:n1": 1, == indices_len[step]["n1:e1:n2"]
}, )
] assert (
for minibatch in sampler_dp: len(sampled_subgraph.sampled_csc["n2:e2:n1"].indices)
for step, sampled_subgraph in enumerate(minibatch.sampled_subgraphs): == indices_len[step]["n2:e2:n1"]
assert ( )
len(sampled_subgraph.sampled_csc["n1:e1:n2"].indices)
== indices_len[step]["n1:e1:n2"]
)
assert (
len(sampled_subgraph.sampled_csc["n2:e2:n1"].indices)
== indices_len[step]["n2:e2:n1"]
)
def test_SubgraphSampler_invoke(): def test_SubgraphSampler_invoke():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment