Unverified Commit 845864d2 authored by yxy235, committed by GitHub

[GraphBolt] Modify `SubgraphSampler` to support `seeds`. (#7049)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent ee8b7b39
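Rough usage sketch (illustrative only, not part of this diff): with this change, node seeds and seed edge pairs can both be carried under the unified `seeds` name and are compacted by the new `_seeds_preprocess` path. The toy graph and variable names below are assumptions; only the GraphBolt calls exercised in the tests of this commit (`gb.fused_csc_sampling_graph`, `gb.ItemSet`, `gb.ItemSampler`, `gb.NeighborSampler`) are taken from it.

import torch
import dgl.graphbolt as gb

# Toy 10-node ring graph in CSC form; any FusedCSCSamplingGraph would do here.
graph = gb.fused_csc_sampling_graph(torch.arange(0, 11), torch.arange(1, 11) % 10)
# Node seeds carried under the unified name "seeds" (seed edge pairs would use
# an (N, 2) tensor with the same name).
itemset = gb.ItemSet(torch.arange(10), names="seeds")
datapipe = gb.ItemSampler(itemset, batch_size=2)
datapipe = gb.NeighborSampler(datapipe, graph, [torch.LongTensor([2])] * 2)
for minibatch in datapipe:
    # `compacted_seeds` is populated by the new `_seeds_preprocess` path.
    print(minibatch.compacted_seeds)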
@@ -3,6 +3,7 @@
from collections import defaultdict
from typing import Dict
import torch
from torch.utils.data import functional_datapipe
from .base import etype_str_to_tuple
@@ -69,10 +70,16 @@ class SubgraphSampler(MiniBatchTransformer):
seeds_timestamp = (
minibatch.timestamp if hasattr(minibatch, "timestamp") else None
)
elif minibatch.seeds is not None:
(
seeds,
seeds_timestamp,
minibatch.compacted_seeds,
) = SubgraphSampler._seeds_preprocess(minibatch)
else:
raise ValueError(
-f"Invalid minibatch {minibatch}: Either `node_pairs` or "
-"`seed_nodes` should have a value."
+f"Invalid minibatch {minibatch}: One of `node_pairs`, "
+"`seed_nodes` and `seeds` should have a value."
)
minibatch._seed_nodes = seeds
minibatch._seeds_timestamp = seeds_timestamp
@@ -226,6 +233,116 @@ class SubgraphSampler(MiniBatchTransformer):
"""
return datapipe.transform(self._sample)
@staticmethod
def _seeds_preprocess(minibatch):
"""Preprocess `seeds` in a minibatch to construct `unique_seeds`,
`nodes_timestamp` and `compacted_seeds` for further sampling. It
optionally incorporates timestamps for temporal graphs, organizing and
compacting seeds based on their types and timestamps.
Parameters
----------
minibatch: MiniBatch
The minibatch.
Returns
-------
unique_seeds: torch.Tensor or Dict[str, torch.Tensor]
A tensor or a dictionary of tensors representing the unique seeds.
In heterogeneous graphs, seeds are returned for each node type.
nodes_timestamp: None, torch.Tensor or Dict[str, torch.Tensor]
Timestamps for each seed. This is only returned if `minibatch`
includes timestamps and the graph is temporal.
compacted_seeds: torch.Tensor or Dict[str, torch.Tensor]
Representation of compacted seeds corresponding to `seeds`, where
all node ids inside are compacted.
"""
use_timestamp = hasattr(minibatch, "timestamp")
seeds = minibatch.seeds
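# Illustrative example (homogeneous, no timestamps): seeds of shape (2, 2),
# e.g. torch.tensor([[10, 20], [20, 30]]), are flattened to [10, 20, 20, 30]
# and passed to `unique_and_compact`, giving `unique_seeds` such as
# [10, 20, 30] and `compacted_seeds` [[0, 1], [1, 2]], i.e. positions into
# `unique_seeds` reshaped back to the original (2, 2) shape. The exact id
# ordering depends on `unique_and_compact`.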
is_heterogeneous = isinstance(seeds, Dict)
if is_heterogeneous:
# Collect nodes from all types of input.
nodes = defaultdict(list)
nodes_timestamp = None
if use_timestamp:
nodes_timestamp = defaultdict(list)
for etype, pair in seeds.items():
assert pair.ndim == 1 or (
pair.ndim == 2 and pair.shape[1] == 2
), (
"Only tensors of shape (N,) or (N, 2) are "
+ f"supported now, but got {pair.shape}."
)
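# An edge type string "src:rel:dst" yields its endpoint node types via
# `[::2]`, e.g. "n1:e1:n2" -> ["n1", "n2"].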
ntypes = etype[:].split(":")[::2]
pair = pair.view(pair.shape[0], -1)
if use_timestamp:
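# Assumed layout: negative pairs (if any) follow the positive pairs, so
# each negative seed reuses the timestamp of the positive pair it was
# generated from.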
negative_ratio = (
pair.shape[0] // minibatch.timestamp[etype].shape[0] - 1
)
neg_timestamp = minibatch.timestamp[
etype
].repeat_interleave(negative_ratio)
for i, ntype in enumerate(ntypes):
nodes[ntype].append(pair[:, i])
if use_timestamp:
nodes_timestamp[ntype].append(
minibatch.timestamp[etype]
)
nodes_timestamp[ntype].append(neg_timestamp)
# Unique and compact the collected nodes.
if use_timestamp:
(
unique_seeds,
nodes_timestamp,
compacted,
) = compact_temporal_nodes(nodes, nodes_timestamp)
else:
unique_seeds, compacted = unique_and_compact(nodes)
nodes_timestamp = None
compacted_seeds = {}
# Map back in same order as collect.
for etype, pair in seeds.items():
if pair.ndim == 1:
compacted_seeds[etype] = compacted[etype].pop(0)
else:
src_type, _, dst_type = etype_str_to_tuple(etype)
src = compacted[src_type].pop(0)
dst = compacted[dst_type].pop(0)
compacted_seeds[etype] = torch.cat((src, dst)).view(2, -1).T
else:
# Collect nodes from all types of input.
nodes = [seeds.view(-1)]
nodes_timestamp = None
if use_timestamp:
# Timestamps for source and destination nodes are the same.
negative_ratio = (
seeds.shape[0] // minibatch.timestamp.shape[0] - 1
)
neg_timestamp = minibatch.timestamp.repeat_interleave(
negative_ratio
)
seeds_timestamp = torch.cat(
(minibatch.timestamp, neg_timestamp)
)
nodes_timestamp = [seeds_timestamp for _ in range(seeds.ndim)]
# Unique and compact the collected nodes.
if use_timestamp:
(
unique_seeds,
nodes_timestamp,
compacted,
) = compact_temporal_nodes(nodes, nodes_timestamp)
else:
unique_seeds, compacted = unique_and_compact(nodes)
nodes_timestamp = None
# Map back in same order as collect.
compacted_seeds = compacted[0].view(seeds.shape)
return (
unique_seeds,
nodes_timestamp,
compacted_seeds,
)
def sample_subgraphs(self, seeds, seeds_timestamp): def sample_subgraphs(self, seeds, seeds_timestamp):
"""Sample subgraphs from the given seeds, possibly with temporal constraints. """Sample subgraphs from the given seeds, possibly with temporal constraints.
...
@@ -40,7 +40,7 @@ def _get_sampler(sampler_type):
)
-def test_SubgraphSampler_invoke():
+def test_SubgraphSampler_invoke_nodes():
itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
@@ -56,7 +56,7 @@ def test_SubgraphSampler_invoke():
@pytest.mark.parametrize("labor", [False, True])
-def test_NeighborSampler_invoke(labor):
+def test_NeighborSampler_invoke_nodes(labor):
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
F.ctx()
)
@@ -79,7 +79,7 @@ def test_NeighborSampler_invoke(labor):
@pytest.mark.parametrize("labor", [False, True])
-def test_NeighborSampler_fanouts(labor):
+def test_NeighborSampler_fanouts_nodes(labor):
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
F.ctx()
)
@@ -108,7 +108,7 @@ def test_NeighborSampler_fanouts(labor):
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
-def test_SubgraphSampler_Node(sampler_type):
+def test_SubgraphSampler_Node_seed_nodes(sampler_type):
_check_sampler_type(sampler_type)
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
F.ctx()
@@ -140,7 +140,7 @@ def to_link_batch(data):
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
-def test_SubgraphSampler_Link(sampler_type):
+def test_SubgraphSampler_Link_node_pairs(sampler_type):
_check_sampler_type(sampler_type)
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
F.ctx()
@@ -168,7 +168,7 @@ def test_SubgraphSampler_Link(sampler_type):
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
-def test_SubgraphSampler_Link_With_Negative(sampler_type):
+def test_SubgraphSampler_Link_With_Negative_node_pairs(sampler_type):
_check_sampler_type(sampler_type)
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
F.ctx()
@@ -219,7 +219,7 @@ def get_hetero_graph():
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
-def test_SubgraphSampler_Node_Hetero(sampler_type):
+def test_SubgraphSampler_Node_seed_nodes_Hetero(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx())
items = torch.arange(3)
@@ -248,7 +248,7 @@ def test_SubgraphSampler_Node_Hetero(sampler_type):
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
-def test_SubgraphSampler_Link_Hetero(sampler_type):
+def test_SubgraphSampler_Link_Hetero_node_pairs(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx())
first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
@@ -292,7 +292,7 @@ def test_SubgraphSampler_Link_Hetero(sampler_type):
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
-def test_SubgraphSampler_Link_Hetero_With_Negative(sampler_type):
+def test_SubgraphSampler_Link_Hetero_With_Negative_node_pairs(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx())
first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
@@ -337,7 +337,7 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(sampler_type):
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
-def test_SubgraphSampler_Link_Hetero_Unknown_Etype(sampler_type):
+def test_SubgraphSampler_Link_Hetero_Unknown_Etype_node_pairs(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx())
first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
@@ -382,7 +382,9 @@ def test_SubgraphSampler_Link_Hetero_Unknown_Etype(sampler_type):
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
-def test_SubgraphSampler_Link_Hetero_With_Negative_Unknown_Etype(sampler_type):
+def test_SubgraphSampler_Link_Hetero_With_Negative_Unknown_Etype_node_pairs(
+sampler_type,
+):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx())
first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
@@ -432,7 +434,7 @@ def test_SubgraphSampler_Link_Hetero_With_Negative_Unknown_Etype(sampler_type):
"replace",
[False, True],
)
-def test_SubgraphSampler_Random_Hetero_Graph(sampler_type, replace):
+def test_SubgraphSampler_Random_Hetero_Graph_seed_ndoes(sampler_type, replace):
_check_sampler_type(sampler_type)
if F._default_context_str == "gpu" and replace == True:
pytest.skip("Sampling with replacement not yet supported on GPU.")
@@ -523,7 +525,7 @@ def test_SubgraphSampler_Random_Hetero_Graph(sampler_type, replace):
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
-def test_SubgraphSampler_without_dedpulication_Homo(sampler_type):
+def test_SubgraphSampler_without_dedpulication_Homo_seed_nodes(sampler_type):
_check_sampler_type(sampler_type)
graph = dgl.graph(
([5, 0, 1, 5, 6, 7, 2, 2, 4], [0, 1, 2, 2, 2, 2, 3, 4, 4])
@@ -587,7 +589,7 @@ def test_SubgraphSampler_without_dedpulication_Homo(sampler_type):
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
-def test_SubgraphSampler_without_dedpulication_Hetero(sampler_type):
+def test_SubgraphSampler_without_dedpulication_Hetero_seed_nodes(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx())
items = torch.arange(2)
@@ -680,7 +682,7 @@ def test_SubgraphSampler_without_dedpulication_Hetero(sampler_type):
reason="Fails due to different result on the GPU.",
)
@pytest.mark.parametrize("labor", [False, True])
-def test_SubgraphSampler_unique_csc_format_Homo_cpu(labor):
+def test_SubgraphSampler_unique_csc_format_Homo_cpu_seed_nodes(labor):
torch.manual_seed(1205)
graph = dgl.graph(([5, 0, 6, 7, 2, 2, 4], [0, 1, 2, 2, 3, 4, 4]))
graph = gb.from_dglgraph(graph, True).to(F.ctx())
@@ -739,7 +741,7 @@ def test_SubgraphSampler_unique_csc_format_Homo_cpu(labor):
reason="Fails due to different result on the CPU.",
)
@pytest.mark.parametrize("labor", [False, True])
-def test_SubgraphSampler_unique_csc_format_Homo_gpu(labor):
+def test_SubgraphSampler_unique_csc_format_Homo_gpu_seed_nodes(labor):
torch.manual_seed(1205)
graph = dgl.graph(([5, 0, 7, 7, 2, 4], [0, 1, 2, 2, 3, 4]))
graph = gb.from_dglgraph(graph, is_homogeneous=True).to(F.ctx())
@@ -794,7 +796,7 @@ def test_SubgraphSampler_unique_csc_format_Homo_gpu(labor):
@pytest.mark.parametrize("labor", [False, True])
-def test_SubgraphSampler_unique_csc_format_Hetero(labor):
+def test_SubgraphSampler_unique_csc_format_Hetero_seed_nodes(labor):
graph = get_hetero_graph().to(F.ctx())
itemset = gb.ItemSetDict(
{"n2": gb.ItemSet(torch.arange(2), names="seed_nodes")}
@@ -878,7 +880,7 @@ def test_SubgraphSampler_unique_csc_format_Hetero(labor):
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
-def test_SubgraphSampler_Hetero_multifanout_per_layer(sampler_type):
+def test_SubgraphSampler_Hetero_multifanout_per_layer_seed_nodes(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx())
items_n1 = torch.tensor([0])
@@ -939,3 +941,1259 @@ def test_SubgraphSampler_Hetero_multifanout_per_layer(sampler_type):
len(sampled_subgraph.sampled_csc["n2:e2:n1"].indices)
== indices_len[step]["n2:e2:n1"]
)
def test_SubgraphSampler_invoke():
itemset = gb.ItemSet(torch.arange(10), names="seeds")
item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
# Invoke via class constructor.
datapipe = gb.SubgraphSampler(item_sampler)
with pytest.raises(NotImplementedError):
next(iter(datapipe))
# Invoke via the functional form.
datapipe = item_sampler.sample_subgraph()
with pytest.raises(NotImplementedError):
next(iter(datapipe))
@pytest.mark.parametrize("labor", [False, True])
def test_NeighborSampler_invoke(labor):
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
F.ctx()
)
itemset = gb.ItemSet(torch.arange(10), names="seeds")
item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
# Invoke via class constructor.
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
datapipe = Sampler(item_sampler, graph, fanouts)
assert len(list(datapipe)) == 5
# Invoke via the functional form.
if labor:
datapipe = item_sampler.sample_layer_neighbor(graph, fanouts)
else:
datapipe = item_sampler.sample_neighbor(graph, fanouts)
assert len(list(datapipe)) == 5
@pytest.mark.parametrize("labor", [False, True])
def test_NeighborSampler_fanouts(labor):
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
F.ctx()
)
itemset = gb.ItemSet(torch.arange(10), names="seeds")
item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
num_layer = 2
# `fanouts` is a list of tensors.
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
if labor:
datapipe = item_sampler.sample_layer_neighbor(graph, fanouts)
else:
datapipe = item_sampler.sample_neighbor(graph, fanouts)
assert len(list(datapipe)) == 5
# `fanouts` is a list of integers.
fanouts = [2 for _ in range(num_layer)]
if labor:
datapipe = item_sampler.sample_layer_neighbor(graph, fanouts)
else:
datapipe = item_sampler.sample_neighbor(graph, fanouts)
assert len(list(datapipe)) == 5
@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_Node(sampler_type):
_check_sampler_type(sampler_type)
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
F.ctx()
)
items = torch.arange(10)
names = "seeds"
if sampler_type == SamplerType.Temporal:
graph.node_attributes = {"timestamp": torch.arange(20).to(F.ctx())}
graph.edge_attributes = {
"timestamp": torch.arange(len(graph.indices)).to(F.ctx())
}
items = (items, torch.arange(10))
names = (names, "timestamp")
itemset = gb.ItemSet(items, names=names)
item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
sampler = _get_sampler(sampler_type)
sampler_dp = sampler(item_sampler, graph, fanouts)
assert len(list(sampler_dp)) == 5
for data in sampler_dp:
assert torch.equal(
data.compacted_seeds, torch.tensor([0, 1]).to(F.ctx())
)
@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_Link(sampler_type):
_check_sampler_type(sampler_type)
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
F.ctx()
)
items = torch.arange(20).reshape(-1, 2)
names = "seeds"
if sampler_type == SamplerType.Temporal:
graph.node_attributes = {"timestamp": torch.arange(20).to(F.ctx())}
graph.edge_attributes = {
"timestamp": torch.arange(len(graph.indices)).to(F.ctx())
}
items = (items, torch.arange(10))
names = (names, "timestamp")
itemset = gb.ItemSet(items, names=names)
datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
sampler = _get_sampler(sampler_type)
datapipe = sampler(datapipe, graph, fanouts)
# TODO: `exclude_seed_edges` doesn't support `seeds` now.
# datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5
for data in datapipe:
assert torch.equal(
data.compacted_seeds, torch.tensor([[0, 1], [2, 3]]).to(F.ctx())
)
@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_Link_With_Negative(sampler_type):
_check_sampler_type(sampler_type)
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
F.ctx()
)
items = torch.arange(20).reshape(-1, 2)
names = "seeds"
if sampler_type == SamplerType.Temporal:
graph.node_attributes = {"timestamp": torch.arange(20).to(F.ctx())}
graph.edge_attributes = {
"timestamp": torch.arange(len(graph.indices)).to(F.ctx())
}
items = (items, torch.arange(10))
names = (names, "timestamp")
itemset = gb.ItemSet(items, names=names)
datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
datapipe = gb.UniformNegativeSampler(datapipe, graph, 1)
sampler = _get_sampler(sampler_type)
datapipe = sampler(datapipe, graph, fanouts)
# TODO: `exclude_seed_edges` doesn't support `seeds` now.
# datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5
@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_Node_Hetero(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx())
items = torch.arange(3)
names = "seeds"
if sampler_type == SamplerType.Temporal:
graph.node_attributes = {
"timestamp": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx())
}
graph.edge_attributes = {
"timestamp": torch.arange(graph.indices.numel()).to(F.ctx())
}
items = (items, torch.randint(0, 10, (3,)))
names = (names, "timestamp")
itemset = gb.ItemSetDict({"n2": gb.ItemSet(items, names=names)})
item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
sampler = _get_sampler(sampler_type)
sampler_dp = sampler(item_sampler, graph, fanouts)
assert len(list(sampler_dp)) == 2
expected_compacted_seeds = {"n2": [torch.tensor([0, 1]), torch.tensor([0])]}
for step, minibatch in enumerate(sampler_dp):
assert len(minibatch.sampled_subgraphs) == num_layer
for etype, compacted_seeds in minibatch.compacted_seeds.items():
assert torch.equal(
compacted_seeds,
expected_compacted_seeds[etype][step].to(F.ctx()),
)
@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_Link_Hetero(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx())
first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
first_names = "seeds"
second_items = torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T
second_names = "seeds"
if sampler_type == SamplerType.Temporal:
graph.node_attributes = {
"timestamp": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx())
}
graph.edge_attributes = {
"timestamp": torch.arange(graph.indices.numel()).to(F.ctx())
}
first_items = (first_items, torch.randint(0, 10, (4,)))
first_names = (first_names, "timestamp")
second_items = (second_items, torch.randint(0, 10, (6,)))
second_names = (second_names, "timestamp")
itemset = gb.ItemSetDict(
{
"n1:e1:n2": gb.ItemSet(
first_items,
names=first_names,
),
"n2:e2:n1": gb.ItemSet(
second_items,
names=second_names,
),
}
)
datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
sampler = _get_sampler(sampler_type)
datapipe = sampler(datapipe, graph, fanouts)
# TODO: `exclude_seed_edges` doesn't support `seeds` now.
# datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5
for data in datapipe:
for compacted_seeds in data.compacted_seeds.values():
if sampler_type == SamplerType.Temporal:
assert torch.equal(
compacted_seeds, torch.tensor([[0, 0], [1, 1]]).to(F.ctx())
)
else:
assert torch.equal(
torch.sort(compacted_seeds.T, dim=1)[0].T,
torch.tensor([[0, 0], [0, 1]]).to(F.ctx()),
)
@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_Link_Hetero_With_Negative(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx())
first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
first_names = "seeds"
second_items = torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T
second_names = "seeds"
if sampler_type == SamplerType.Temporal:
graph.node_attributes = {
"timestamp": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx())
}
graph.edge_attributes = {
"timestamp": torch.arange(graph.indices.numel()).to(F.ctx())
}
first_items = (first_items, torch.randint(0, 10, (4,)))
first_names = (first_names, "timestamp")
second_items = (second_items, torch.randint(0, 10, (6,)))
second_names = (second_names, "timestamp")
itemset = gb.ItemSetDict(
{
"n1:e1:n2": gb.ItemSet(
first_items,
names=first_names,
),
"n2:e2:n1": gb.ItemSet(
second_items,
names=second_names,
),
}
)
datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
datapipe = gb.UniformNegativeSampler(datapipe, graph, 1)
sampler = _get_sampler(sampler_type)
datapipe = sampler(datapipe, graph, fanouts)
# TODO: `exclude_seed_edges` doesn't support `seeds` now.
# datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5
@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_Link_Hetero_Unknown_Etype(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx())
first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
first_names = "seeds"
second_items = torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T
second_names = "seeds"
if sampler_type == SamplerType.Temporal:
graph.node_attributes = {
"timestamp": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx())
}
graph.edge_attributes = {
"timestamp": torch.arange(graph.indices.numel()).to(F.ctx())
}
first_items = (first_items, torch.randint(0, 10, (4,)))
first_names = (first_names, "timestamp")
second_items = (second_items, torch.randint(0, 10, (6,)))
second_names = (second_names, "timestamp")
# "e11" and "e22" are not valid edge types.
itemset = gb.ItemSetDict(
{
"n1:e11:n2": gb.ItemSet(
first_items,
names=first_names,
),
"n2:e22:n1": gb.ItemSet(
second_items,
names=second_names,
),
}
)
datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
sampler = _get_sampler(sampler_type)
datapipe = sampler(datapipe, graph, fanouts)
# TODO: `exclude_seed_edges` doesn't support `seeds` now.
# datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5
@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_Link_Hetero_With_Negative_Unknown_Etype(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx())
first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
first_names = "seeds"
second_items = torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T
second_names = "seeds"
if sampler_type == SamplerType.Temporal:
graph.node_attributes = {
"timestamp": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx())
}
graph.edge_attributes = {
"timestamp": torch.arange(graph.indices.numel()).to(F.ctx())
}
first_items = (first_items, torch.randint(0, 10, (4,)))
first_names = (first_names, "timestamp")
second_items = (second_items, torch.randint(0, 10, (6,)))
second_names = (second_names, "timestamp")
# "e11" and "e22" are not valid edge types.
itemset = gb.ItemSetDict(
{
"n1:e11:n2": gb.ItemSet(
first_items,
names=first_names,
),
"n2:e22:n1": gb.ItemSet(
second_items,
names=second_names,
),
}
)
datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
datapipe = gb.UniformNegativeSampler(datapipe, graph, 1)
sampler = _get_sampler(sampler_type)
datapipe = sampler(datapipe, graph, fanouts)
# TODO: `exclude_seed_edges` doesn't support `seeds` now.
# datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5
@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
@pytest.mark.parametrize(
"replace",
[False, True],
)
def test_SubgraphSampler_Random_Hetero_Graph(sampler_type, replace):
_check_sampler_type(sampler_type)
if F._default_context_str == "gpu" and replace == True:
pytest.skip("Sampling with replacement not yet supported on GPU.")
num_nodes = 5
num_edges = 9
num_ntypes = 3
num_etypes = 3
(
csc_indptr,
indices,
node_type_offset,
type_per_edge,
node_type_to_id,
edge_type_to_id,
) = gb_test_utils.random_hetero_graph(
num_nodes, num_edges, num_ntypes, num_etypes
)
node_attributes = {}
edge_attributes = {
"A1": torch.randn(num_edges),
"A2": torch.randn(num_edges),
}
if sampler_type == SamplerType.Temporal:
node_attributes["timestamp"] = torch.randint(0, 10, (num_nodes,))
edge_attributes["timestamp"] = torch.randint(0, 10, (num_edges,))
graph = gb.fused_csc_sampling_graph(
csc_indptr,
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
node_attributes=node_attributes,
edge_attributes=edge_attributes,
).to(F.ctx())
first_items = torch.tensor([0])
first_names = "seeds"
second_items = torch.tensor([0])
second_names = "seeds"
if sampler_type == SamplerType.Temporal:
first_items = (first_items, torch.randint(0, 10, (1,)))
first_names = (first_names, "timestamp")
second_items = (second_items, torch.randint(0, 10, (1,)))
second_names = (second_names, "timestamp")
itemset = gb.ItemSetDict(
{
"n2": gb.ItemSet(first_items, names=first_names),
"n1": gb.ItemSet(second_items, names=second_names),
}
)
item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
sampler = _get_sampler(sampler_type)
sampler_dp = sampler(item_sampler, graph, fanouts, replace=replace)
for data in sampler_dp:
for sampledsubgraph in data.sampled_subgraphs:
for _, value in sampledsubgraph.sampled_csc.items():
assert torch.equal(
torch.ge(
value.indices,
torch.zeros(len(value.indices)).to(F.ctx()),
),
torch.ones(len(value.indices)).to(F.ctx()),
)
assert torch.equal(
torch.ge(
value.indptr, torch.zeros(len(value.indptr)).to(F.ctx())
),
torch.ones(len(value.indptr)).to(F.ctx()),
)
for _, value in sampledsubgraph.original_column_node_ids.items():
assert torch.equal(
torch.ge(value, torch.zeros(len(value)).to(F.ctx())),
torch.ones(len(value)).to(F.ctx()),
)
for _, value in sampledsubgraph.original_row_node_ids.items():
assert torch.equal(
torch.ge(value, torch.zeros(len(value)).to(F.ctx())),
torch.ones(len(value)).to(F.ctx()),
)
@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_without_dedpulication_Homo_Node(sampler_type):
_check_sampler_type(sampler_type)
graph = dgl.graph(
([5, 0, 1, 5, 6, 7, 2, 2, 4], [0, 1, 2, 2, 2, 2, 3, 4, 4])
)
graph = gb.from_dglgraph(graph, True).to(F.ctx())
seed_nodes = torch.LongTensor([0, 3, 4])
items = seed_nodes
names = "seeds"
if sampler_type == SamplerType.Temporal:
graph.node_attributes = {
"timestamp": torch.zeros(graph.csc_indptr.numel() - 1).to(F.ctx())
}
graph.edge_attributes = {
"timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
}
items = (items, torch.randint(1, 10, (3,)))
names = (names, "timestamp")
itemset = gb.ItemSet(items, names=names)
item_sampler = gb.ItemSampler(itemset, batch_size=len(seed_nodes)).copy_to(
F.ctx()
)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
sampler = _get_sampler(sampler_type)
if sampler_type == SamplerType.Temporal:
datapipe = sampler(item_sampler, graph, fanouts)
else:
datapipe = sampler(item_sampler, graph, fanouts, deduplicate=False)
length = [17, 7]
compacted_indices = [
(torch.arange(0, 10) + 7).to(F.ctx()),
(torch.arange(0, 4) + 3).to(F.ctx()),
]
indptr = [
torch.tensor([0, 1, 2, 4, 4, 6, 8, 10]).to(F.ctx()),
torch.tensor([0, 1, 2, 4]).to(F.ctx()),
]
seeds = [
torch.tensor([0, 2, 2, 3, 4, 4, 5]).to(F.ctx()),
torch.tensor([0, 3, 4]).to(F.ctx()),
]
for data in datapipe:
for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
assert len(sampled_subgraph.original_row_node_ids) == length[step]
assert torch.equal(
sampled_subgraph.sampled_csc.indices, compacted_indices[step]
)
assert torch.equal(
sampled_subgraph.sampled_csc.indptr, indptr[step]
)
assert torch.equal(
torch.sort(sampled_subgraph.original_column_node_ids)[0],
seeds[step],
)
@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_without_dedpulication_Hetero_Node(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx())
items = torch.arange(2)
names = "seeds"
if sampler_type == SamplerType.Temporal:
graph.node_attributes = {
"timestamp": torch.zeros(graph.csc_indptr.numel() - 1).to(F.ctx())
}
graph.edge_attributes = {
"timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
}
items = (items, torch.randint(1, 10, (2,)))
names = (names, "timestamp")
itemset = gb.ItemSetDict({"n2": gb.ItemSet(items, names=names)})
item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
sampler = _get_sampler(sampler_type)
if sampler_type == SamplerType.Temporal:
datapipe = sampler(item_sampler, graph, fanouts)
else:
datapipe = sampler(item_sampler, graph, fanouts, deduplicate=False)
csc_formats = [
{
"n1:e1:n2": gb.CSCFormatBase(
indptr=torch.tensor([0, 2, 4]),
indices=torch.tensor([4, 5, 6, 7]),
),
"n2:e2:n1": gb.CSCFormatBase(
indptr=torch.tensor([0, 2, 4, 6, 8]),
indices=torch.tensor([2, 3, 4, 5, 6, 7, 8, 9]),
),
},
{
"n1:e1:n2": gb.CSCFormatBase(
indptr=torch.tensor([0, 2, 4]),
indices=torch.tensor([0, 1, 2, 3]),
),
"n2:e2:n1": gb.CSCFormatBase(
indptr=torch.tensor([0]),
indices=torch.tensor([], dtype=torch.int64),
),
},
]
original_column_node_ids = [
{
"n1": torch.tensor([0, 1, 1, 0]),
"n2": torch.tensor([0, 1]),
},
{
"n1": torch.tensor([], dtype=torch.int64),
"n2": torch.tensor([0, 1]),
},
]
original_row_node_ids = [
{
"n1": torch.tensor([0, 1, 1, 0, 0, 1, 1, 0]),
"n2": torch.tensor([0, 1, 0, 2, 0, 1, 0, 1, 0, 2]),
},
{
"n1": torch.tensor([0, 1, 1, 0]),
"n2": torch.tensor([0, 1]),
},
]
for data in datapipe:
for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
for ntype in ["n1", "n2"]:
assert torch.equal(
sampled_subgraph.original_row_node_ids[ntype],
original_row_node_ids[step][ntype].to(F.ctx()),
)
assert torch.equal(
sampled_subgraph.original_column_node_ids[ntype],
original_column_node_ids[step][ntype].to(F.ctx()),
)
for etype in ["n1:e1:n2", "n2:e2:n1"]:
assert torch.equal(
sampled_subgraph.sampled_csc[etype].indices,
csc_formats[step][etype].indices.to(F.ctx()),
)
assert torch.equal(
sampled_subgraph.sampled_csc[etype].indptr,
csc_formats[step][etype].indptr.to(F.ctx()),
)
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Fails due to different result on the GPU.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_unique_csc_format_Homo_Node_cpu(labor):
torch.manual_seed(1205)
graph = dgl.graph(([5, 0, 6, 7, 2, 2, 4], [0, 1, 2, 2, 3, 4, 4]))
graph = gb.from_dglgraph(graph, True).to(F.ctx())
seed_nodes = torch.LongTensor([0, 3, 4])
itemset = gb.ItemSet(seed_nodes, names="seeds")
item_sampler = gb.ItemSampler(itemset, batch_size=len(seed_nodes)).copy_to(
F.ctx()
)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
datapipe = Sampler(
item_sampler,
graph,
fanouts,
deduplicate=True,
)
original_row_node_ids = [
torch.tensor([0, 3, 4, 5, 2, 6, 7]).to(F.ctx()),
torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
]
compacted_indices = [
torch.tensor([3, 4, 4, 2, 5, 6]).to(F.ctx()),
torch.tensor([3, 4, 4, 2]).to(F.ctx()),
]
indptr = [
torch.tensor([0, 1, 2, 4, 4, 6]).to(F.ctx()),
torch.tensor([0, 1, 2, 4]).to(F.ctx()),
]
seeds = [
torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
torch.tensor([0, 3, 4]).to(F.ctx()),
]
for data in datapipe:
for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
assert torch.equal(
sampled_subgraph.original_row_node_ids,
original_row_node_ids[step],
)
assert torch.equal(
sampled_subgraph.sampled_csc.indices, compacted_indices[step]
)
assert torch.equal(
sampled_subgraph.sampled_csc.indptr, indptr[step]
)
assert torch.equal(
sampled_subgraph.original_column_node_ids, seeds[step]
)
@unittest.skipIf(
F._default_context_str == "cpu",
reason="Fails due to different result on the CPU.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_unique_csc_format_Homo_Node_gpu(labor):
torch.manual_seed(1205)
graph = dgl.graph(([5, 0, 7, 7, 2, 4], [0, 1, 2, 2, 3, 4]))
graph = gb.from_dglgraph(graph, is_homogeneous=True).to(F.ctx())
seed_nodes = torch.LongTensor([0, 3, 4])
itemset = gb.ItemSet(seed_nodes, names="seeds")
item_sampler = gb.ItemSampler(itemset, batch_size=len(seed_nodes)).copy_to(
F.ctx()
)
num_layer = 2
fanouts = [torch.LongTensor([-1]) for _ in range(num_layer)]
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
datapipe = Sampler(
item_sampler,
graph,
fanouts,
deduplicate=True,
)
original_row_node_ids = [
torch.tensor([0, 3, 4, 2, 5, 7]).to(F.ctx()),
torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
]
compacted_indices = [
torch.tensor([4, 3, 2, 5, 5]).to(F.ctx()),
torch.tensor([4, 3, 2]).to(F.ctx()),
]
indptr = [
torch.tensor([0, 1, 2, 3, 5, 5]).to(F.ctx()),
torch.tensor([0, 1, 2, 3]).to(F.ctx()),
]
seeds = [
torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
torch.tensor([0, 3, 4]).to(F.ctx()),
]
for data in datapipe:
for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
assert torch.equal(
sampled_subgraph.original_row_node_ids,
original_row_node_ids[step],
)
assert torch.equal(
sampled_subgraph.sampled_csc.indices, compacted_indices[step]
)
assert torch.equal(
sampled_subgraph.sampled_csc.indptr, indptr[step]
)
assert torch.equal(
sampled_subgraph.original_column_node_ids, seeds[step]
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_unique_csc_format_Hetero_Node(labor):
graph = get_hetero_graph().to(F.ctx())
itemset = gb.ItemSetDict({"n2": gb.ItemSet(torch.arange(2), names="seeds")})
item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
datapipe = Sampler(
item_sampler,
graph,
fanouts,
deduplicate=True,
)
csc_formats = [
{
"n1:e1:n2": gb.CSCFormatBase(
indptr=torch.tensor([0, 2, 4]),
indices=torch.tensor([0, 1, 1, 0]),
),
"n2:e2:n1": gb.CSCFormatBase(
indptr=torch.tensor([0, 2, 4]),
indices=torch.tensor([0, 2, 0, 1]),
),
},
{
"n1:e1:n2": gb.CSCFormatBase(
indptr=torch.tensor([0, 2, 4]),
indices=torch.tensor([0, 1, 1, 0]),
),
"n2:e2:n1": gb.CSCFormatBase(
indptr=torch.tensor([0]),
indices=torch.tensor([], dtype=torch.int64),
),
},
]
original_column_node_ids = [
{
"n1": torch.tensor([0, 1]),
"n2": torch.tensor([0, 1]),
},
{
"n1": torch.tensor([], dtype=torch.int64),
"n2": torch.tensor([0, 1]),
},
]
original_row_node_ids = [
{
"n1": torch.tensor([0, 1]),
"n2": torch.tensor([0, 1, 2]),
},
{
"n1": torch.tensor([0, 1]),
"n2": torch.tensor([0, 1]),
},
]
for data in datapipe:
for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
for ntype in ["n1", "n2"]:
assert torch.equal(
sampled_subgraph.original_row_node_ids[ntype],
original_row_node_ids[step][ntype].to(F.ctx()),
)
assert torch.equal(
sampled_subgraph.original_column_node_ids[ntype],
original_column_node_ids[step][ntype].to(F.ctx()),
)
for etype in ["n1:e1:n2", "n2:e2:n1"]:
assert torch.equal(
sampled_subgraph.sampled_csc[etype].indices,
csc_formats[step][etype].indices.to(F.ctx()),
)
assert torch.equal(
sampled_subgraph.sampled_csc[etype].indptr,
csc_formats[step][etype].indptr.to(F.ctx()),
)
@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_Hetero_multifanout_per_layer(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx())
items_n1 = torch.tensor([0])
items_n2 = torch.tensor([1])
names = "seeds"
if sampler_type == SamplerType.Temporal:
graph.node_attributes = {
"timestamp": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx())
}
graph.edge_attributes = {
"timestamp": torch.arange(graph.indices.numel()).to(F.ctx())
}
# All edges can be sampled.
items_n1 = (items_n1, torch.tensor([10]))
items_n2 = (items_n2, torch.tensor([10]))
names = (names, "timestamp")
itemset = gb.ItemSetDict(
{
"n1": gb.ItemSet(items=items_n1, names=names),
"n2": gb.ItemSet(items=items_n2, names=names),
}
)
item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
num_layer = 2
# The number of edges to be sampled for each edge type of each node.
fanouts = [torch.LongTensor([2, 1]) for _ in range(num_layer)]
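# Note (assumed semantics): each fanout entry is ordered by edge type id,
# i.e. 2 edges per seed for "n1:e1:n2" and 1 for "n2:e2:n1", which matches
# the expected `indices_len` values below.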
sampler = _get_sampler(sampler_type)
sampler_dp = sampler(item_sampler, graph, fanouts)
if sampler_type == SamplerType.Temporal:
indices_len = [
{
"n1:e1:n2": 4,
"n2:e2:n1": 3,
},
{
"n1:e1:n2": 2,
"n2:e2:n1": 1,
},
]
else:
indices_len = [
{
"n1:e1:n2": 4,
"n2:e2:n1": 2,
},
{
"n1:e1:n2": 2,
"n2:e2:n1": 1,
},
]
for minibatch in sampler_dp:
for step, sampled_subgraph in enumerate(minibatch.sampled_subgraphs):
assert (
len(sampled_subgraph.sampled_csc["n1:e1:n2"].indices)
== indices_len[step]["n1:e1:n2"]
)
assert (
len(sampled_subgraph.sampled_csc["n2:e2:n1"].indices)
== indices_len[step]["n2:e2:n1"]
)
@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_without_dedpulication_Homo_Link(sampler_type):
_check_sampler_type(sampler_type)
graph = dgl.graph(
([5, 0, 1, 5, 6, 7, 2, 2, 4], [0, 1, 2, 2, 2, 2, 3, 4, 4])
)
graph = gb.from_dglgraph(graph, True).to(F.ctx())
seed_nodes = torch.LongTensor([[0, 1], [3, 5]])
items = seed_nodes
names = "seeds"
if sampler_type == SamplerType.Temporal:
graph.node_attributes = {
"timestamp": torch.zeros(graph.csc_indptr.numel() - 1).to(F.ctx())
}
graph.edge_attributes = {
"timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
}
items = (items, torch.randint(1, 10, (3,)))
names = (names, "timestamp")
itemset = gb.ItemSet(items, names=names)
item_sampler = gb.ItemSampler(itemset, batch_size=4).copy_to(F.ctx())
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
sampler = _get_sampler(sampler_type)
if sampler_type == SamplerType.Temporal:
datapipe = sampler(item_sampler, graph, fanouts)
else:
datapipe = sampler(item_sampler, graph, fanouts, deduplicate=False)
length = [13, 7]
compacted_indices = [
(torch.arange(0, 6) + 7).to(F.ctx()),
(torch.arange(0, 3) + 4).to(F.ctx()),
]
indptr = [
torch.tensor([0, 1, 2, 3, 3, 3, 4, 6]).to(F.ctx()),
torch.tensor([0, 1, 2, 3, 3]).to(F.ctx()),
]
seeds = [
torch.tensor([0, 0, 1, 2, 3, 5, 5]).to(F.ctx()),
torch.tensor([0, 1, 3, 5]).to(F.ctx()),
]
for data in datapipe:
for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
assert len(sampled_subgraph.original_row_node_ids) == length[step]
assert torch.equal(
sampled_subgraph.sampled_csc.indices, compacted_indices[step]
)
assert torch.equal(
sampled_subgraph.sampled_csc.indptr, indptr[step]
)
assert torch.equal(
torch.sort(sampled_subgraph.original_column_node_ids)[0],
seeds[step],
)
@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_without_dedpulication_Hetero_Link(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx())
items = torch.arange(2).view(1, 2)
names = "seeds"
if sampler_type == SamplerType.Temporal:
graph.node_attributes = {
"timestamp": torch.zeros(graph.csc_indptr.numel() - 1).to(F.ctx())
}
graph.edge_attributes = {
"timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
}
items = (items, torch.randint(1, 10, (2,)))
names = (names, "timestamp")
itemset = gb.ItemSetDict({"n1:e1:n2": gb.ItemSet(items, names=names)})
item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
sampler = _get_sampler(sampler_type)
if sampler_type == SamplerType.Temporal:
datapipe = sampler(item_sampler, graph, fanouts)
else:
datapipe = sampler(item_sampler, graph, fanouts, deduplicate=False)
csc_formats = [
{
"n1:e1:n2": gb.CSCFormatBase(
indptr=torch.tensor([0, 2, 4, 6]),
indices=torch.tensor([3, 4, 5, 6, 7, 8]),
),
"n2:e2:n1": gb.CSCFormatBase(
indptr=torch.tensor([0, 2, 4, 6]),
indices=torch.tensor([3, 4, 5, 6, 7, 8]),
),
},
{
"n1:e1:n2": gb.CSCFormatBase(
indptr=torch.tensor([0, 2]),
indices=torch.tensor([1, 2]),
),
"n2:e2:n1": gb.CSCFormatBase(
indptr=torch.tensor([0, 2]),
indices=torch.tensor([1, 2], dtype=torch.int64),
),
},
]
original_column_node_ids = [
{
"n1": torch.tensor([0, 1, 0]),
"n2": torch.tensor([1, 0, 2]),
},
{
"n1": torch.tensor([0]),
"n2": torch.tensor([1]),
},
]
original_row_node_ids = [
{
"n1": torch.tensor([0, 1, 0, 1, 0, 0, 1, 0, 1]),
"n2": torch.tensor([1, 0, 2, 0, 2, 0, 1, 0, 2]),
},
{
"n1": torch.tensor([0, 1, 0]),
"n2": torch.tensor([1, 0, 2]),
},
]
for data in datapipe:
for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
for ntype in ["n1", "n2"]:
assert torch.equal(
sampled_subgraph.original_row_node_ids[ntype],
original_row_node_ids[step][ntype].to(F.ctx()),
)
assert torch.equal(
sampled_subgraph.original_column_node_ids[ntype],
original_column_node_ids[step][ntype].to(F.ctx()),
)
for etype in ["n1:e1:n2", "n2:e2:n1"]:
assert torch.equal(
sampled_subgraph.sampled_csc[etype].indices,
csc_formats[step][etype].indices.to(F.ctx()),
)
assert torch.equal(
sampled_subgraph.sampled_csc[etype].indptr,
csc_formats[step][etype].indptr.to(F.ctx()),
)
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Fails due to different result on the GPU.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_unique_csc_format_Homo_Link_cpu(labor):
torch.manual_seed(1205)
graph = dgl.graph(([5, 0, 6, 7, 2, 2, 4], [0, 1, 2, 2, 3, 4, 4]))
graph = gb.from_dglgraph(graph, True).to(F.ctx())
seed_nodes = torch.LongTensor([[0, 3], [4, 4]])
itemset = gb.ItemSet(seed_nodes, names="seeds")
item_sampler = gb.ItemSampler(itemset, batch_size=4).copy_to(F.ctx())
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
datapipe = Sampler(
item_sampler,
graph,
fanouts,
deduplicate=True,
)
original_row_node_ids = [
torch.tensor([0, 3, 4, 5, 2, 6, 7]).to(F.ctx()),
torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
]
compacted_indices = [
torch.tensor([3, 4, 4, 2, 5, 6]).to(F.ctx()),
torch.tensor([3, 4, 4, 2]).to(F.ctx()),
]
indptr = [
torch.tensor([0, 1, 2, 4, 4, 6]).to(F.ctx()),
torch.tensor([0, 1, 2, 4]).to(F.ctx()),
]
seeds = [
torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
torch.tensor([0, 3, 4]).to(F.ctx()),
]
for data in datapipe:
for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
assert torch.equal(
sampled_subgraph.original_row_node_ids,
original_row_node_ids[step],
)
assert torch.equal(
sampled_subgraph.sampled_csc.indices, compacted_indices[step]
)
assert torch.equal(
sampled_subgraph.sampled_csc.indptr, indptr[step]
)
assert torch.equal(
sampled_subgraph.original_column_node_ids, seeds[step]
)
@unittest.skipIf(
F._default_context_str == "cpu",
reason="Fails due to different result on the CPU.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_unique_csc_format_Homo_Link_gpu(labor):
torch.manual_seed(1205)
graph = dgl.graph(([5, 0, 7, 7, 2, 4], [0, 1, 2, 2, 3, 4]))
graph = gb.from_dglgraph(graph, is_homogeneous=True).to(F.ctx())
seed_nodes = torch.LongTensor([[0, 3], [4, 4]])
itemset = gb.ItemSet(seed_nodes, names="seeds")
item_sampler = gb.ItemSampler(itemset, batch_size=4).copy_to(F.ctx())
num_layer = 2
fanouts = [torch.LongTensor([-1]) for _ in range(num_layer)]
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
datapipe = Sampler(
item_sampler,
graph,
fanouts,
deduplicate=True,
)
original_row_node_ids = [
torch.tensor([0, 3, 4, 2, 5, 7]).to(F.ctx()),
torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
]
compacted_indices = [
torch.tensor([4, 3, 2, 5, 5]).to(F.ctx()),
torch.tensor([4, 3, 2]).to(F.ctx()),
]
indptr = [
torch.tensor([0, 1, 2, 3, 5, 5]).to(F.ctx()),
torch.tensor([0, 1, 2, 3]).to(F.ctx()),
]
seeds = [
torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
torch.tensor([0, 3, 4]).to(F.ctx()),
]
for data in datapipe:
for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
assert torch.equal(
sampled_subgraph.original_row_node_ids,
original_row_node_ids[step],
)
assert torch.equal(
sampled_subgraph.sampled_csc.indices, compacted_indices[step]
)
assert torch.equal(
sampled_subgraph.sampled_csc.indptr, indptr[step]
)
assert torch.equal(
sampled_subgraph.original_column_node_ids, seeds[step]
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_unique_csc_format_Hetero_Link(labor):
graph = get_hetero_graph().to(F.ctx())
itemset = gb.ItemSetDict(
{"n1:e1:n2": gb.ItemSet(torch.tensor([[0, 1]]), names="seeds")}
)
item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
datapipe = Sampler(
item_sampler,
graph,
fanouts,
deduplicate=True,
)
csc_formats = [
{
"n1:e1:n2": gb.CSCFormatBase(
indptr=torch.tensor([0, 2, 4, 6]),
indices=torch.tensor([1, 0, 0, 1, 0, 1]),
),
"n2:e2:n1": gb.CSCFormatBase(
indptr=torch.tensor([0, 2, 4]),
indices=torch.tensor([1, 2, 1, 0]),
),
},
{
"n1:e1:n2": gb.CSCFormatBase(
indptr=torch.tensor([0, 2]),
indices=torch.tensor([1, 0]),
),
"n2:e2:n1": gb.CSCFormatBase(
indptr=torch.tensor([0, 2]),
indices=torch.tensor([1, 2], dtype=torch.int64),
),
},
]
original_column_node_ids = [
{
"n1": torch.tensor([0, 1]),
"n2": torch.tensor([0, 1, 2]),
},
{
"n1": torch.tensor([0]),
"n2": torch.tensor([1]),
},
]
original_row_node_ids = [
{
"n1": torch.tensor([0, 1]),
"n2": torch.tensor([0, 1, 2]),
},
{
"n1": torch.tensor([0, 1]),
"n2": torch.tensor([0, 1, 2]),
},
]
for data in datapipe:
for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
for ntype in ["n1", "n2"]:
assert torch.equal(
torch.sort(sampled_subgraph.original_row_node_ids[ntype])[
0
],
original_row_node_ids[step][ntype].to(F.ctx()),
)
assert torch.equal(
torch.sort(
sampled_subgraph.original_column_node_ids[ntype]
)[0],
original_column_node_ids[step][ntype].to(F.ctx()),
)
for etype in ["n1:e1:n2", "n2:e2:n1"]:
assert torch.equal(
sampled_subgraph.sampled_csc[etype].indices,
csc_formats[step][etype].indices.to(F.ctx()),
)
assert torch.equal(
sampled_subgraph.sampled_csc[etype].indptr,
csc_formats[step][etype].indptr.to(F.ctx()),
)