Unverified Commit 13e7c2fa authored by Andrei Ivanov's avatar Andrei Ivanov Committed by GitHub
Browse files

[GraphBolt] Improving `subgraph_sampler` tests. (#7047)

parent 8204fe19
import unittest import unittest
import warnings
from enum import Enum from enum import Enum
from functools import partial from functools import partial
...@@ -9,7 +10,6 @@ import dgl ...@@ -9,7 +10,6 @@ import dgl
import dgl.graphbolt as gb import dgl.graphbolt as gb
import pytest import pytest
import torch import torch
from torchdata.datapipes.iter import Mapper
from . import gb_test_utils from . import gb_test_utils
...@@ -22,6 +22,12 @@ def _check_sampler_type(sampler_type): ...@@ -22,6 +22,12 @@ def _check_sampler_type(sampler_type):
) )
def _check_sampler_len(sampler, lenExp):
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
assert len(list(sampler)) == lenExp
class SamplerType(Enum): class SamplerType(Enum):
Normal = 0 Normal = 0
Layer = 1 Layer = 1
...@@ -128,7 +134,7 @@ def test_SubgraphSampler_Node_seed_nodes(sampler_type): ...@@ -128,7 +134,7 @@ def test_SubgraphSampler_Node_seed_nodes(sampler_type):
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
sampler = _get_sampler(sampler_type) sampler = _get_sampler(sampler_type)
sampler_dp = sampler(item_sampler, graph, fanouts) sampler_dp = sampler(item_sampler, graph, fanouts)
assert len(list(sampler_dp)) == 5 _check_sampler_len(sampler_dp, 5)
def to_link_batch(data): def to_link_batch(data):
...@@ -161,7 +167,7 @@ def test_SubgraphSampler_Link_node_pairs(sampler_type): ...@@ -161,7 +167,7 @@ def test_SubgraphSampler_Link_node_pairs(sampler_type):
sampler = _get_sampler(sampler_type) sampler = _get_sampler(sampler_type)
datapipe = sampler(datapipe, graph, fanouts) datapipe = sampler(datapipe, graph, fanouts)
datapipe = datapipe.transform(partial(gb.exclude_seed_edges)) datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5 _check_sampler_len(datapipe, 5)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -190,7 +196,7 @@ def test_SubgraphSampler_Link_With_Negative_node_pairs(sampler_type): ...@@ -190,7 +196,7 @@ def test_SubgraphSampler_Link_With_Negative_node_pairs(sampler_type):
sampler = _get_sampler(sampler_type) sampler = _get_sampler(sampler_type)
datapipe = sampler(datapipe, graph, fanouts) datapipe = sampler(datapipe, graph, fanouts)
datapipe = datapipe.transform(partial(gb.exclude_seed_edges)) datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5 _check_sampler_len(datapipe, 5)
def get_hetero_graph(): def get_hetero_graph():
...@@ -239,7 +245,9 @@ def test_SubgraphSampler_Node_seed_nodes_Hetero(sampler_type): ...@@ -239,7 +245,9 @@ def test_SubgraphSampler_Node_seed_nodes_Hetero(sampler_type):
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
sampler = _get_sampler(sampler_type) sampler = _get_sampler(sampler_type)
sampler_dp = sampler(item_sampler, graph, fanouts) sampler_dp = sampler(item_sampler, graph, fanouts)
assert len(list(sampler_dp)) == 2 _check_sampler_len(sampler_dp, 2)
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
for minibatch in sampler_dp: for minibatch in sampler_dp:
assert len(minibatch.sampled_subgraphs) == num_layer assert len(minibatch.sampled_subgraphs) == num_layer
...@@ -285,7 +293,7 @@ def test_SubgraphSampler_Link_Hetero_node_pairs(sampler_type): ...@@ -285,7 +293,7 @@ def test_SubgraphSampler_Link_Hetero_node_pairs(sampler_type):
sampler = _get_sampler(sampler_type) sampler = _get_sampler(sampler_type)
datapipe = sampler(datapipe, graph, fanouts) datapipe = sampler(datapipe, graph, fanouts)
datapipe = datapipe.transform(partial(gb.exclude_seed_edges)) datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5 _check_sampler_len(datapipe, 5)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -330,7 +338,7 @@ def test_SubgraphSampler_Link_Hetero_With_Negative_node_pairs(sampler_type): ...@@ -330,7 +338,7 @@ def test_SubgraphSampler_Link_Hetero_With_Negative_node_pairs(sampler_type):
sampler = _get_sampler(sampler_type) sampler = _get_sampler(sampler_type)
datapipe = sampler(datapipe, graph, fanouts) datapipe = sampler(datapipe, graph, fanouts)
datapipe = datapipe.transform(partial(gb.exclude_seed_edges)) datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5 _check_sampler_len(datapipe, 5)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -375,7 +383,7 @@ def test_SubgraphSampler_Link_Hetero_Unknown_Etype_node_pairs(sampler_type): ...@@ -375,7 +383,7 @@ def test_SubgraphSampler_Link_Hetero_Unknown_Etype_node_pairs(sampler_type):
sampler = _get_sampler(sampler_type) sampler = _get_sampler(sampler_type)
datapipe = sampler(datapipe, graph, fanouts) datapipe = sampler(datapipe, graph, fanouts)
datapipe = datapipe.transform(partial(gb.exclude_seed_edges)) datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5 _check_sampler_len(datapipe, 5)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -423,7 +431,7 @@ def test_SubgraphSampler_Link_Hetero_With_Negative_Unknown_Etype_node_pairs( ...@@ -423,7 +431,7 @@ def test_SubgraphSampler_Link_Hetero_With_Negative_Unknown_Etype_node_pairs(
sampler = _get_sampler(sampler_type) sampler = _get_sampler(sampler_type)
datapipe = sampler(datapipe, graph, fanouts) datapipe = sampler(datapipe, graph, fanouts)
datapipe = datapipe.transform(partial(gb.exclude_seed_edges)) datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5 _check_sampler_len(datapipe, 5)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -493,30 +501,26 @@ def test_SubgraphSampler_Random_Hetero_Graph_seed_ndoes(sampler_type, replace): ...@@ -493,30 +501,26 @@ def test_SubgraphSampler_Random_Hetero_Graph_seed_ndoes(sampler_type, replace):
sampler_dp = sampler(item_sampler, graph, fanouts, replace=replace) sampler_dp = sampler(item_sampler, graph, fanouts, replace=replace)
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
for data in sampler_dp: for data in sampler_dp:
for sampledsubgraph in data.sampled_subgraphs: for sampledsubgraph in data.sampled_subgraphs:
for _, value in sampledsubgraph.sampled_csc.items(): for _, value in sampledsubgraph.sampled_csc.items():
for idx in [value.indices, value.indptr]:
assert torch.equal( assert torch.equal(
torch.ge( torch.ge(idx, torch.zeros(len(idx)).to(F.ctx())),
value.indices, torch.ones(len(idx)).to(F.ctx()),
torch.zeros(len(value.indices)).to(F.ctx()),
),
torch.ones(len(value.indices)).to(F.ctx()),
) )
node_ids = [
sampledsubgraph.original_column_node_ids,
sampledsubgraph.original_row_node_ids,
]
for ids in node_ids:
for _, value in ids.items():
assert torch.equal( assert torch.equal(
torch.ge( torch.ge(
value.indptr, torch.zeros(len(value.indptr)).to(F.ctx()) value, torch.zeros(len(value)).to(F.ctx())
), ),
torch.ones(len(value.indptr)).to(F.ctx()),
)
for _, value in sampledsubgraph.original_column_node_ids.items():
assert torch.equal(
torch.ge(value, torch.zeros(len(value)).to(F.ctx())),
torch.ones(len(value)).to(F.ctx()),
)
for _, value in sampledsubgraph.original_row_node_ids.items():
assert torch.equal(
torch.ge(value, torch.zeros(len(value)).to(F.ctx())),
torch.ones(len(value)).to(F.ctx()), torch.ones(len(value)).to(F.ctx()),
) )
...@@ -570,11 +574,16 @@ def test_SubgraphSampler_without_dedpulication_Homo_seed_nodes(sampler_type): ...@@ -570,11 +574,16 @@ def test_SubgraphSampler_without_dedpulication_Homo_seed_nodes(sampler_type):
torch.tensor([0, 2, 2, 3, 4, 4, 5]).to(F.ctx()), torch.tensor([0, 2, 2, 3, 4, 4, 5]).to(F.ctx()),
torch.tensor([0, 3, 4]).to(F.ctx()), torch.tensor([0, 3, 4]).to(F.ctx()),
] ]
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
for data in datapipe: for data in datapipe:
for step, sampled_subgraph in enumerate(data.sampled_subgraphs): for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
assert len(sampled_subgraph.original_row_node_ids) == length[step] assert (
len(sampled_subgraph.original_row_node_ids) == length[step]
)
assert torch.equal( assert torch.equal(
sampled_subgraph.sampled_csc.indices, compacted_indices[step] sampled_subgraph.sampled_csc.indices,
compacted_indices[step],
) )
assert torch.equal( assert torch.equal(
sampled_subgraph.sampled_csc.indptr, indptr[step] sampled_subgraph.sampled_csc.indptr, indptr[step]
...@@ -585,6 +594,51 @@ def test_SubgraphSampler_without_dedpulication_Homo_seed_nodes(sampler_type): ...@@ -585,6 +594,51 @@ def test_SubgraphSampler_without_dedpulication_Homo_seed_nodes(sampler_type):
) )
def _assert_hetero_values(
    datapipe, original_row_node_ids, original_column_node_ids, csc_formats
):
    """Compare every sampled heterogeneous subgraph against expected values.

    For each minibatch produced by *datapipe* and each hop (``step``) of its
    ``sampled_subgraphs``, checks the original row/column node IDs per node
    type and the sampled CSC indices/indptr per edge type. Expected tensors
    are moved to the test backend context (``F.ctx()``) before comparison.
    """
    node_types = ("n1", "n2")
    edge_types = ("n1:e1:n2", "n2:e2:n1")
    for minibatch in datapipe:
        for step, subgraph in enumerate(minibatch.sampled_subgraphs):
            for ntype in node_types:
                expected_rows = original_row_node_ids[step][ntype].to(F.ctx())
                assert torch.equal(
                    subgraph.original_row_node_ids[ntype], expected_rows
                )
                expected_cols = original_column_node_ids[step][ntype].to(
                    F.ctx()
                )
                assert torch.equal(
                    subgraph.original_column_node_ids[ntype], expected_cols
                )
            for etype in edge_types:
                expected_csc = csc_formats[step][etype]
                assert torch.equal(
                    subgraph.sampled_csc[etype].indices,
                    expected_csc.indices.to(F.ctx()),
                )
                assert torch.equal(
                    subgraph.sampled_csc[etype].indptr,
                    expected_csc.indptr.to(F.ctx()),
                )
def _assert_homo_values(
datapipe, original_row_node_ids, compacted_indices, indptr, seeds
):
for data in datapipe:
for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
assert torch.equal(
sampled_subgraph.original_row_node_ids,
original_row_node_ids[step],
)
assert torch.equal(
sampled_subgraph.sampled_csc.indices, compacted_indices[step]
)
assert torch.equal(
sampled_subgraph.sampled_csc.indptr, indptr[step]
)
assert torch.equal(
sampled_subgraph.original_column_node_ids, seeds[step]
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"sampler_type", "sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal], [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
...@@ -655,25 +709,13 @@ def test_SubgraphSampler_without_dedpulication_Hetero_seed_nodes(sampler_type): ...@@ -655,25 +709,13 @@ def test_SubgraphSampler_without_dedpulication_Hetero_seed_nodes(sampler_type):
}, },
] ]
for data in datapipe: with warnings.catch_warnings():
for step, sampled_subgraph in enumerate(data.sampled_subgraphs): warnings.simplefilter("ignore", category=UserWarning)
for ntype in ["n1", "n2"]: _assert_hetero_values(
assert torch.equal( datapipe,
sampled_subgraph.original_row_node_ids[ntype], original_row_node_ids,
original_row_node_ids[step][ntype].to(F.ctx()), original_column_node_ids,
) csc_formats,
assert torch.equal(
sampled_subgraph.original_column_node_ids[ntype],
original_column_node_ids[step][ntype].to(F.ctx()),
)
for etype in ["n1:e1:n2", "n2:e2:n1"]:
assert torch.equal(
sampled_subgraph.sampled_csc[etype].indices,
csc_formats[step][etype].indices.to(F.ctx()),
)
assert torch.equal(
sampled_subgraph.sampled_csc[etype].indptr,
csc_formats[step][etype].indptr.to(F.ctx()),
) )
...@@ -719,20 +761,8 @@ def test_SubgraphSampler_unique_csc_format_Homo_cpu_seed_nodes(labor): ...@@ -719,20 +761,8 @@ def test_SubgraphSampler_unique_csc_format_Homo_cpu_seed_nodes(labor):
torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()), torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
torch.tensor([0, 3, 4]).to(F.ctx()), torch.tensor([0, 3, 4]).to(F.ctx()),
] ]
for data in datapipe: _assert_homo_values(
for step, sampled_subgraph in enumerate(data.sampled_subgraphs): datapipe, original_row_node_ids, compacted_indices, indptr, seeds
assert torch.equal(
sampled_subgraph.original_row_node_ids,
original_row_node_ids[step],
)
assert torch.equal(
sampled_subgraph.sampled_csc.indices, compacted_indices[step]
)
assert torch.equal(
sampled_subgraph.sampled_csc.indptr, indptr[step]
)
assert torch.equal(
sampled_subgraph.original_column_node_ids, seeds[step]
) )
...@@ -778,20 +808,8 @@ def test_SubgraphSampler_unique_csc_format_Homo_gpu_seed_nodes(labor): ...@@ -778,20 +808,8 @@ def test_SubgraphSampler_unique_csc_format_Homo_gpu_seed_nodes(labor):
torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()), torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
torch.tensor([0, 3, 4]).to(F.ctx()), torch.tensor([0, 3, 4]).to(F.ctx()),
] ]
for data in datapipe: _assert_homo_values(
for step, sampled_subgraph in enumerate(data.sampled_subgraphs): datapipe, original_row_node_ids, compacted_indices, indptr, seeds
assert torch.equal(
sampled_subgraph.original_row_node_ids,
original_row_node_ids[step],
)
assert torch.equal(
sampled_subgraph.sampled_csc.indices, compacted_indices[step]
)
assert torch.equal(
sampled_subgraph.sampled_csc.indptr, indptr[step]
)
assert torch.equal(
sampled_subgraph.original_column_node_ids, seeds[step]
) )
...@@ -853,26 +871,8 @@ def test_SubgraphSampler_unique_csc_format_Hetero_seed_nodes(labor): ...@@ -853,26 +871,8 @@ def test_SubgraphSampler_unique_csc_format_Hetero_seed_nodes(labor):
"n2": torch.tensor([0, 1]), "n2": torch.tensor([0, 1]),
}, },
] ]
_assert_hetero_values(
for data in datapipe: datapipe, original_row_node_ids, original_column_node_ids, csc_formats
for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
for ntype in ["n1", "n2"]:
assert torch.equal(
sampled_subgraph.original_row_node_ids[ntype],
original_row_node_ids[step][ntype].to(F.ctx()),
)
assert torch.equal(
sampled_subgraph.original_column_node_ids[ntype],
original_column_node_ids[step][ntype].to(F.ctx()),
)
for etype in ["n1:e1:n2", "n2:e2:n1"]:
assert torch.equal(
sampled_subgraph.sampled_csc[etype].indices,
csc_formats[step][etype].indices.to(F.ctx()),
)
assert torch.equal(
sampled_subgraph.sampled_csc[etype].indptr,
csc_formats[step][etype].indptr.to(F.ctx()),
) )
...@@ -886,7 +886,9 @@ def test_SubgraphSampler_Hetero_multifanout_per_layer_seed_nodes(sampler_type): ...@@ -886,7 +886,9 @@ def test_SubgraphSampler_Hetero_multifanout_per_layer_seed_nodes(sampler_type):
items_n1 = torch.tensor([0]) items_n1 = torch.tensor([0])
items_n2 = torch.tensor([1]) items_n2 = torch.tensor([1])
names = "seed_nodes" names = "seed_nodes"
item_length = 2
if sampler_type == SamplerType.Temporal: if sampler_type == SamplerType.Temporal:
item_length = 3
graph.node_attributes = { graph.node_attributes = {
"timestamp": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx()) "timestamp": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx())
} }
...@@ -909,30 +911,23 @@ def test_SubgraphSampler_Hetero_multifanout_per_layer_seed_nodes(sampler_type): ...@@ -909,30 +911,23 @@ def test_SubgraphSampler_Hetero_multifanout_per_layer_seed_nodes(sampler_type):
fanouts = [torch.LongTensor([2, 1]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2, 1]) for _ in range(num_layer)]
sampler = _get_sampler(sampler_type) sampler = _get_sampler(sampler_type)
sampler_dp = sampler(item_sampler, graph, fanouts) sampler_dp = sampler(item_sampler, graph, fanouts)
if sampler_type == SamplerType.Temporal:
indices_len = [ indices_len = [
{ {
"n1:e1:n2": 4, "n1:e1:n2": 4,
"n2:e2:n1": 3, "n2:e2:n1": item_length,
},
{
"n1:e1:n2": 2,
"n2:e2:n1": 1,
},
]
else:
indices_len = [
{
"n1:e1:n2": 4,
"n2:e2:n1": 2,
}, },
{ {
"n1:e1:n2": 2, "n1:e1:n2": 2,
"n2:e2:n1": 1, "n2:e2:n1": 1,
}, },
] ]
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
for minibatch in sampler_dp: for minibatch in sampler_dp:
for step, sampled_subgraph in enumerate(minibatch.sampled_subgraphs): for step, sampled_subgraph in enumerate(
minibatch.sampled_subgraphs
):
assert ( assert (
len(sampled_subgraph.sampled_csc["n1:e1:n2"].indices) len(sampled_subgraph.sampled_csc["n1:e1:n2"].indices)
== indices_len[step]["n1:e1:n2"] == indices_len[step]["n1:e1:n2"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment