"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "ff126ae2c2cd8a7f2e5acf931ec89e98d0844292"
Unverified commit f5981789, authored by czkkkkkk, committed by GitHub

[Graphbolt] Support temporal sampling in SubgraphSampler. (#6846)

parent d32a5980
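The commit threads per-seed timestamps from the minibatch through `SubgraphSampler` and into `TemporalNeighborSampler`. Below is a minimal end-to-end sketch of the intended usage, modeled on the tests added in this commit; the toy graph, its sizes, and the timestamp values are invented for illustration, and the `timestamp` attribute name is the one the tests use.

```python
import torch
import dgl.graphbolt as gb

# Toy homogeneous graph in CSC form (4 nodes, 6 edges); values are illustrative.
graph = gb.fused_csc_sampling_graph(
    torch.LongTensor([0, 2, 4, 5, 6]),
    torch.LongTensor([1, 2, 0, 3, 1, 2]),
    node_attributes={"timestamp": torch.LongTensor([0, 1, 2, 3])},
    edge_attributes={"timestamp": torch.LongTensor([0, 1, 2, 3, 4, 5])},
)

# Pair every seed node with the timestamp at which it is queried; the
# "timestamp" field is what SubgraphSampler now forwards to the sampler.
itemset = gb.ItemSet(
    (torch.LongTensor([0, 1, 2, 3]), torch.LongTensor([3, 3, 3, 3])),
    names=("seed_nodes", "timestamp"),
)
datapipe = gb.ItemSampler(itemset, batch_size=2)

# Two layers with fanout 2; neighbors newer than a seed's timestamp are
# expected to be excluded from its sampled subgraph.
fanouts = [torch.LongTensor([2]) for _ in range(2)]
datapipe = gb.TemporalNeighborSampler(
    datapipe,
    graph,
    fanouts,
    node_timestamp_attr_name="timestamp",
    edge_timestamp_attr_name="timestamp",
)
for minibatch in datapipe:
    print(minibatch.sampled_subgraphs)
```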
@@ -5,6 +5,7 @@ from .gpu_cached_feature import *
 from .in_subgraph_sampler import *
 from .legacy_dataset import *
 from .neighbor_sampler import *
+from .temporal_neighbor_sampler import *
 from .ondisk_dataset import *
 from .ondisk_metadata import *
 from .sampled_subgraph_impl import *
......
@@ -761,8 +761,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
     def temporal_sample_neighbors(
         self,
-        nodes: torch.Tensor,
-        input_nodes_timestamp: torch.Tensor,
+        nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
+        input_nodes_timestamp: Union[torch.Tensor, Dict[str, torch.Tensor]],
         fanouts: torch.Tensor,
         replace: bool = False,
         probs_name: Optional[str] = None,
......
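With this change `temporal_sample_neighbors` accepts either plain tensors or per-node-type dictionaries for the seeds and their timestamps. A hedged sketch of the calling convention follows; only the first three parameters are visible in this hunk, so the trailing timestamp-attribute keyword names are an assumption mirroring the `TemporalNeighborSampler` arguments used in the tests further down, and the node-type keys mentioned in the comment are invented.

```python
import torch
import dgl.graphbolt as gb

# Homogeneous toy graph; node and edge timestamp values are illustrative.
graph = gb.fused_csc_sampling_graph(
    torch.LongTensor([0, 2, 4, 6]),
    torch.LongTensor([1, 2, 0, 2, 0, 1]),
    node_attributes={"timestamp": torch.LongTensor([0, 1, 2])},
    edge_attributes={"timestamp": torch.LongTensor([0, 1, 2, 3, 4, 5])},
)

# Plain tensors: seed nodes plus the timestamp each seed is queried at.
# The two keyword arguments below are assumed, not shown in this hunk.
subgraph = graph.temporal_sample_neighbors(
    torch.LongTensor([0, 2]),
    torch.LongTensor([5, 5]),
    torch.LongTensor([2]),
    node_timestamp_attr_name="timestamp",
    edge_timestamp_attr_name="timestamp",
)

# On a heterogeneous graph the same call accepts dictionaries keyed by node
# type, e.g. nodes={"n1": ..., "n2": ...} with a matching timestamp dict.
```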
@@ -67,7 +67,7 @@ class InSubgraphSampler(SubgraphSampler):
         self.graph = graph
         self.sampler = graph.in_subgraph

-    def sample_subgraphs(self, seeds, seeds_timestamp=None):
+    def sample_subgraphs(self, seeds, seeds_timestamp):
         subgraph = self.sampler(seeds)
         (
             original_row_node_ids,
......
@@ -117,7 +117,7 @@ class NeighborSampler(SubgraphSampler):
         self.deduplicate = deduplicate
         self.sampler = graph.sample_neighbors

-    def sample_subgraphs(self, seeds, seeds_timestamp=None):
+    def sample_subgraphs(self, seeds, seeds_timestamp):
         subgraphs = []
         num_layers = len(self.fanouts)
         # Enrich seeds with all node types.
......
@@ -89,7 +89,10 @@ class TemporalNeighborSampler(SubgraphSampler):
         self.edge_timestamp_attr_name = edge_timestamp_attr_name
         self.sampler = graph.temporal_sample_neighbors

-    def sample_subgraphs(self, seeds, seeds_timestamp=None):
+    def sample_subgraphs(self, seeds, seeds_timestamp):
+        assert (
+            seeds_timestamp is not None
+        ), "seeds_timestamp must be provided for temporal neighbor sampling."
         subgraphs = []
         num_layers = len(self.fanouts)
         # Enrich seeds with all node types.
@@ -117,10 +120,10 @@ class TemporalNeighborSampler(SubgraphSampler):
                 original_row_node_ids,
                 compacted_csc_formats,
                 row_timestamps,
-            ) = compact_csc_format(subgraph.node_pairs, seeds, seeds_timestamp)
+            ) = compact_csc_format(subgraph.sampled_csc, seeds, seeds_timestamp)
             subgraph = SampledSubgraphImpl(
-                node_pairs=compacted_csc_formats,
+                sampled_csc=compacted_csc_formats,
                 original_column_node_ids=seeds,
                 original_row_node_ids=original_row_node_ids,
                 original_edge_ids=subgraph.original_edge_ids,
......
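The temporal sampler now forwards the seeds' timestamps into `compact_csc_format` and receives per-row timestamps back, which is what the `row_timestamps` unpacked above carries. A small sketch of that helper in isolation follows, assuming it is importable from `dgl.graphbolt.internal` like the other utilities changed in this commit; the node IDs and timestamps are invented.

```python
import torch
import dgl.graphbolt as gb
from dgl.graphbolt.internal import compact_csc_format  # assumed import path

# One sampled layer in CSC form: 3 destination (seed) nodes, 4 sampled sources.
csc = gb.CSCFormatBase(
    indptr=torch.LongTensor([0, 2, 3, 4]),
    indices=torch.LongTensor([7, 9, 7, 11]),
)
seeds = torch.LongTensor([3, 5, 8])
seeds_timestamp = torch.LongTensor([20, 21, 22])

rows, compacted, row_timestamps = compact_csc_format(csc, seeds, seeds_timestamp)
# `rows` lists the seeds followed by the sampled sources (no deduplication);
# `row_timestamps` carries the seed timestamps followed by each source's
# timestamp, broadcast from the destination it was sampled for.
```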
@@ -61,6 +61,61 @@ def unique_and_compact(
     return unique_and_compact_per_type(nodes)


+def compact_temporal_nodes(nodes, nodes_timestamp):
+    """Compact a list of temporal nodes without unique.
+
+    Note that since there is no unique, the nodes and nodes_timestamp are
+    simply concatenated. And the compacted nodes are consecutive numbers
+    starting from 0.
+
+    Parameters
+    ----------
+    nodes : List[torch.Tensor] or Dict[str, List[torch.Tensor]]
+        List of nodes for compacting. The compact operator will be done per
+        type.
+        - If `nodes` is a list of tensor: All the tensors will compact
+          together, usually it is used for homogeneous graph.
+        - If `nodes` is a list of dictionary: The keys should be node type
+          and the values should be corresponding nodes, the compact will be
+          done per type, usually it is used for heterogeneous graph.
+    nodes_timestamp : List[torch.Tensor] or Dict[str, List[torch.Tensor]]
+        List of timestamps for compacting.
+
+    Returns
+    -------
+    Tuple[nodes, nodes_timestamp, compacted_node_list]
+        The concatenated nodes and nodes_timestamp, and the compacted node
+        list, where IDs inside are replaced with compacted node IDs.
+    """
+
+    def _compact_per_type(per_type_nodes, per_type_nodes_timestamp):
+        nums = [node.size(0) for node in per_type_nodes]
+        per_type_nodes = torch.cat(per_type_nodes)
+        per_type_nodes_timestamp = torch.cat(per_type_nodes_timestamp)
+        compacted_nodes = torch.arange(
+            0,
+            per_type_nodes.numel(),
+            dtype=per_type_nodes.dtype,
+            device=per_type_nodes.device,
+        )
+        compacted_nodes = list(compacted_nodes.split(nums))
+        return per_type_nodes, per_type_nodes_timestamp, compacted_nodes
+
+    if isinstance(nodes, dict):
+        ret_nodes, ret_timestamp, compacted = {}, {}, {}
+        for ntype, nodes_of_type in nodes.items():
+            (
+                ret_nodes[ntype],
+                ret_timestamp[ntype],
+                compacted[ntype],
+            ) = _compact_per_type(nodes_of_type, nodes_timestamp[ntype])
+        return ret_nodes, ret_timestamp, compacted
+    else:
+        return _compact_per_type(nodes, nodes_timestamp)
+
+
 def unique_and_compact_csc_formats(
     csc_formats: Union[
         Tuple[torch.Tensor, torch.Tensor],
@@ -236,7 +291,8 @@ def compact_csc_format(
         A tensor of original row node IDs (per type) of all nodes in the input.
         The compacted CSC formats, where node IDs are replaced with mapped node
         IDs ranging from 0 to N.
-        The source timestamps (per type) of all nodes in the input if `dst_timestamps` is given.
+        The source timestamps (per type) of all nodes in the input if
+        `dst_timestamps` is given.

     Examples
     --------
@@ -318,8 +374,13 @@ def compact_csc_format(
         src_timestamps = None
         if has_timestamp:
-            src_timestamps = _broadcast_timestamps(
-                compacted_csc_formats, dst_timestamps
+            src_timestamps = torch.cat(
+                [
+                    dst_timestamps,
+                    _broadcast_timestamps(
+                        compacted_csc_formats, dst_timestamps
+                    ),
+                ]
             )
     else:
         compacted_csc_formats = {}
......
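`compact_temporal_nodes` is the timestamp-aware counterpart of `unique_and_compact`: it deliberately skips deduplication so every (node, timestamp) occurrence keeps its own compacted ID. A small sketch of its behaviour, following the function body shown above; the import path is the one `subgraph_sampler.py` uses in this commit, and the node IDs and timestamps are invented.

```python
import torch
from dgl.graphbolt.internal import compact_temporal_nodes

# Two groups of (possibly repeated) node IDs and their timestamps.
nodes = [torch.LongTensor([5, 2, 5]), torch.LongTensor([7])]
timestamps = [torch.LongTensor([10, 11, 12]), torch.LongTensor([13])]

all_nodes, all_ts, compacted = compact_temporal_nodes(nodes, timestamps)
# Nothing is deduplicated: the inputs are concatenated and relabeled with
# consecutive IDs starting from 0, split back per input group.
print(all_nodes)  # tensor([5, 2, 5, 7])
print(all_ts)     # tensor([10, 11, 12, 13])
print(compacted)  # [tensor([0, 1, 2]), tensor([3])]
```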
@@ -6,7 +6,7 @@ from typing import Dict
 from torch.utils.data import functional_datapipe

 from .base import etype_str_to_tuple
-from .internal import unique_and_compact
+from .internal import compact_temporal_nodes, unique_and_compact
 from .minibatch_transformer import MiniBatchTransformer

 __all__ = [
@@ -40,12 +40,16 @@ class SubgraphSampler(MiniBatchTransformer):
         if minibatch.node_pairs is not None:
             (
                 seeds,
+                seeds_timestamp,
                 minibatch.compacted_node_pairs,
                 minibatch.compacted_negative_srcs,
                 minibatch.compacted_negative_dsts,
             ) = self._node_pairs_preprocess(minibatch)
         elif minibatch.seed_nodes is not None:
             seeds = minibatch.seed_nodes
+            seeds_timestamp = (
+                minibatch.timestamp if hasattr(minibatch, "timestamp") else None
+            )
         else:
             raise ValueError(
                 f"Invalid minibatch {minibatch}: Either `node_pairs` or "
@@ -54,10 +58,11 @@ class SubgraphSampler(MiniBatchTransformer):
         (
             minibatch.input_nodes,
             minibatch.sampled_subgraphs,
-        ) = self.sample_subgraphs(seeds)
+        ) = self.sample_subgraphs(seeds, seeds_timestamp)
         return minibatch

     def _node_pairs_preprocess(self, minibatch):
+        use_timestamp = hasattr(minibatch, "timestamp")
         node_pairs = minibatch.node_pairs
         neg_src, neg_dst = minibatch.negative_srcs, minibatch.negative_dsts
         has_neg_src = neg_src is not None
@@ -72,20 +77,44 @@ class SubgraphSampler(MiniBatchTransformer):
            )
            # Collect nodes from all types of input.
            nodes = defaultdict(list)
+           nodes_timestamp = None
+           if use_timestamp:
+               nodes_timestamp = defaultdict(list)
            for etype, (src, dst) in node_pairs.items():
                src_type, _, dst_type = etype_str_to_tuple(etype)
                nodes[src_type].append(src)
                nodes[dst_type].append(dst)
+               if use_timestamp:
+                   nodes_timestamp[src_type].append(minibatch.timestamp[etype])
+                   nodes_timestamp[dst_type].append(minibatch.timestamp[etype])
            if has_neg_src:
                for etype, src in neg_src.items():
                    src_type, _, _ = etype_str_to_tuple(etype)
                    nodes[src_type].append(src.view(-1))
+                   if use_timestamp:
+                       nodes_timestamp[src_type].append(
+                           minibatch.timestamp[etype].repeat_interleave(
+                               src.shape[-1]
+                           )
+                       )
            if has_neg_dst:
                for etype, dst in neg_dst.items():
                    _, _, dst_type = etype_str_to_tuple(etype)
                    nodes[dst_type].append(dst.view(-1))
+                   if use_timestamp:
+                       nodes_timestamp[dst_type].append(
+                           minibatch.timestamp[etype].repeat_interleave(
+                               dst.shape[-1]
+                           )
+                       )
            # Unique and compact the collected nodes.
-           seeds, compacted = unique_and_compact(nodes)
+           if use_timestamp:
+               seeds, nodes_timestamp, compacted = compact_temporal_nodes(
+                   nodes, nodes_timestamp
+               )
+           else:
+               seeds, compacted = unique_and_compact(nodes)
+               nodes_timestamp = None
            (
                compacted_node_pairs,
                compacted_negative_srcs,
@@ -108,12 +137,30 @@ class SubgraphSampler(MiniBatchTransformer):
        else:
            # Collect nodes from all types of input.
            nodes = list(node_pairs)
+           nodes_timestamp = None
+           if use_timestamp:
+               # Timestamp for source and destination nodes are the same.
+               nodes_timestamp = [minibatch.timestamp, minibatch.timestamp]
            if has_neg_src:
                nodes.append(neg_src.view(-1))
+               if use_timestamp:
+                   nodes_timestamp.append(
+                       minibatch.timestamp.repeat_interleave(neg_src.shape[-1])
+                   )
            if has_neg_dst:
                nodes.append(neg_dst.view(-1))
+               if use_timestamp:
+                   nodes_timestamp.append(
+                       minibatch.timestamp.repeat_interleave(neg_dst.shape[-1])
+                   )
            # Unique and compact the collected nodes.
-           seeds, compacted = unique_and_compact(nodes)
+           if use_timestamp:
+               seeds, nodes_timestamp, compacted = compact_temporal_nodes(
+                   nodes, nodes_timestamp
+               )
+           else:
+               seeds, compacted = unique_and_compact(nodes)
+               nodes_timestamp = None
            # Map back in same order as collect.
            compacted_node_pairs = tuple(compacted[:2])
            compacted = compacted[2:]
@@ -132,13 +179,14 @@ class SubgraphSampler(MiniBatchTransformer):
            )
        return (
            seeds,
+           nodes_timestamp,
            compacted_node_pairs,
            compacted_negative_srcs if has_neg_src else None,
            compacted_negative_dsts if has_neg_dst else None,
        )

-    def sample_subgraphs(self, seeds, seeds_timestamp=None):
-        """Sample subgraphs from the given seeds.
+    def sample_subgraphs(self, seeds, seeds_timestamp):
+        """Sample subgraphs from the given seeds, possibly with temporal constraints.

        Any subclass of SubgraphSampler should implement this method.

@@ -147,6 +195,11 @@ class SubgraphSampler(MiniBatchTransformer):
        seeds : Union[torch.Tensor, Dict[str, torch.Tensor]]
            The seed nodes.

+        seeds_timestamp : Union[torch.Tensor, Dict[str, torch.Tensor]]
+            The timestamps of the seed nodes. If given, the sampled subgraphs
+            should not contain any nodes or edges that are newer than the
+            timestamps of the seed nodes. Default: None.
+
        Returns
        -------
        Union[torch.Tensor, Dict[str, torch.Tensor]]
......
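Downstream, every `SubgraphSampler` subclass now receives the seeds' timestamps through this hook, whether or not it uses them. A hypothetical subclass skeleton illustrating the updated contract; the class name and method body are placeholders and not part of the commit.

```python
import dgl.graphbolt as gb


class MyTemporalAwareSampler(gb.SubgraphSampler):
    """Hypothetical subclass; only the hook signature matters here."""

    def sample_subgraphs(self, seeds, seeds_timestamp):
        # `seeds_timestamp` is a tensor (or per-node-type dict) whenever the
        # minibatch carries a "timestamp" field, and None otherwise. Samplers
        # that honour it must not return nodes or edges newer than the seeds.
        subgraphs = []
        ...  # sample one layer per fanout, compacting after each layer
        return seeds, subgraphs
```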
 import unittest
+from enum import Enum
 from functools import partial

 import backend as F
@@ -12,6 +14,31 @@ from torchdata.datapipes.iter import Mapper

 from . import gb_test_utils

+# Skip all tests on GPU.
+pytestmark = pytest.mark.skipif(
+    F._default_context_str != "cpu",
+    reason="GraphBolt sampling tests are only supported on CPU.",
+)
+
+
+class SamplerType(Enum):
+    Normal = 0
+    Layer = 1
+    Temporal = 2
+
+
+def _get_sampler(sampler_type):
+    if sampler_type == SamplerType.Normal:
+        return gb.NeighborSampler
+    if sampler_type == SamplerType.Layer:
+        return gb.LayerNeighborSampler
+    return partial(
+        gb.TemporalNeighborSampler,
+        node_timestamp_attr_name="timestamp",
+        edge_timestamp_attr_name="timestamp",
+    )
+
+
 def test_SubgraphSampler_invoke():
     itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
     item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
@@ -76,17 +103,29 @@ def test_NeighborSampler_fanouts(labor):
     assert len(list(datapipe)) == 5
@pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize(
def test_SubgraphSampler_Node(labor): "sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_Node(sampler_type):
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to( graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
F.ctx() F.ctx()
) )
itemset = gb.ItemSet(torch.arange(10), names="seed_nodes") items = torch.arange(10)
names = "seed_nodes"
if sampler_type == SamplerType.Temporal:
graph.node_attributes = {"timestamp": torch.arange(20).to(F.ctx())}
graph.edge_attributes = {
"timestamp": torch.arange(len(graph.indices)).to(F.ctx())
}
items = (items, torch.arange(10))
names = ("seed_nodes", "timestamp")
itemset = gb.ItemSet(items, names=names)
item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx()) item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
num_layer = 2 num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler sampler = _get_sampler(sampler_type)
sampler_dp = Sampler(item_sampler, graph, fanouts) sampler_dp = sampler(item_sampler, graph, fanouts)
assert len(list(sampler_dp)) == 5 assert len(list(sampler_dp)) == 5
@@ -95,33 +134,57 @@ def to_link_batch(data):
     return block


-@pytest.mark.parametrize("labor", [False, True])
-def test_SubgraphSampler_Link(labor):
+@pytest.mark.parametrize(
+    "sampler_type",
+    [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
+)
+def test_SubgraphSampler_Link(sampler_type):
     graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
         F.ctx()
     )
-    itemset = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2), names="node_pairs")
+    items = torch.arange(20).reshape(-1, 2)
+    names = "node_pairs"
+    if sampler_type == SamplerType.Temporal:
+        graph.node_attributes = {"timestamp": torch.arange(20).to(F.ctx())}
+        graph.edge_attributes = {
+            "timestamp": torch.arange(len(graph.indices)).to(F.ctx())
+        }
+        items = (items, torch.arange(10))
+        names = ("node_pairs", "timestamp")
+    itemset = gb.ItemSet(items, names=names)
     datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
-    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
-    datapipe = Sampler(datapipe, graph, fanouts)
+    sampler = _get_sampler(sampler_type)
+    datapipe = sampler(datapipe, graph, fanouts)
     datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
     assert len(list(datapipe)) == 5
@pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize(
def test_SubgraphSampler_Link_With_Negative(labor): "sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_Link_With_Negative(sampler_type):
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to( graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
F.ctx() F.ctx()
) )
itemset = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2), names="node_pairs") items = torch.arange(20).reshape(-1, 2)
names = "node_pairs"
if sampler_type == SamplerType.Temporal:
graph.node_attributes = {"timestamp": torch.arange(20).to(F.ctx())}
graph.edge_attributes = {
"timestamp": torch.arange(len(graph.indices)).to(F.ctx())
}
items = (items, torch.arange(10))
names = ("node_pairs", "timestamp")
itemset = gb.ItemSet(items, names=names)
datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx()) datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
num_layer = 2 num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
datapipe = gb.UniformNegativeSampler(datapipe, graph, 1) datapipe = gb.UniformNegativeSampler(datapipe, graph, 1)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler sampler = _get_sampler(sampler_type)
datapipe = Sampler(datapipe, graph, fanouts) datapipe = sampler(datapipe, graph, fanouts)
datapipe = datapipe.transform(partial(gb.exclude_seed_edges)) datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5 assert len(list(datapipe)) == 5
@@ -148,34 +211,64 @@ def get_hetero_graph():
     )


-@pytest.mark.parametrize("labor", [False, True])
-def test_SubgraphSampler_Node_Hetero(labor):
+@pytest.mark.parametrize(
+    "sampler_type",
+    [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
+)
+def test_SubgraphSampler_Node_Hetero(sampler_type):
     graph = get_hetero_graph().to(F.ctx())
-    itemset = gb.ItemSetDict(
-        {"n2": gb.ItemSet(torch.arange(3), names="seed_nodes")}
-    )
+    items = torch.arange(3)
+    names = "seed_nodes"
+    if sampler_type == SamplerType.Temporal:
+        graph.node_attributes = {
+            "timestamp": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx())
+        }
+        graph.edge_attributes = {
+            "timestamp": torch.arange(graph.indices.numel()).to(F.ctx())
+        }
+        items = (items, torch.randint(0, 10, (3,)))
+        names = (names, "timestamp")
+    itemset = gb.ItemSetDict({"n2": gb.ItemSet(items, names=names)})
     item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
-    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
-    sampler_dp = Sampler(item_sampler, graph, fanouts)
+    sampler = _get_sampler(sampler_type)
+    sampler_dp = sampler(item_sampler, graph, fanouts)
     assert len(list(sampler_dp)) == 2
     for minibatch in sampler_dp:
         assert len(minibatch.sampled_subgraphs) == num_layer
@pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize(
def test_SubgraphSampler_Link_Hetero(labor): "sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_Link_Hetero(sampler_type):
graph = get_hetero_graph().to(F.ctx()) graph = get_hetero_graph().to(F.ctx())
first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
first_names = "node_pairs"
second_items = torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T
second_names = "node_pairs"
if sampler_type == SamplerType.Temporal:
graph.node_attributes = {
"timestamp": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx())
}
graph.edge_attributes = {
"timestamp": torch.arange(graph.indices.numel()).to(F.ctx())
}
first_items = (first_items, torch.randint(0, 10, (4,)))
first_names = (first_names, "timestamp")
second_items = (second_items, torch.randint(0, 10, (6,)))
second_names = (second_names, "timestamp")
itemset = gb.ItemSetDict( itemset = gb.ItemSetDict(
{ {
"n1:e1:n2": gb.ItemSet( "n1:e1:n2": gb.ItemSet(
torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T, first_items,
names="node_pairs", names=first_names,
), ),
"n2:e2:n1": gb.ItemSet( "n2:e2:n1": gb.ItemSet(
torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T, second_items,
names="node_pairs", names=second_names,
), ),
} }
) )
@@ -183,24 +276,42 @@ def test_SubgraphSampler_Link_Hetero(labor):
     datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
-    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
-    datapipe = Sampler(datapipe, graph, fanouts)
+    sampler = _get_sampler(sampler_type)
+    datapipe = sampler(datapipe, graph, fanouts)
     datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
     assert len(list(datapipe)) == 5
@pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize(
def test_SubgraphSampler_Link_Hetero_With_Negative(labor): "sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_Link_Hetero_With_Negative(sampler_type):
graph = get_hetero_graph().to(F.ctx()) graph = get_hetero_graph().to(F.ctx())
first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
first_names = "node_pairs"
second_items = torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T
second_names = "node_pairs"
if sampler_type == SamplerType.Temporal:
graph.node_attributes = {
"timestamp": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx())
}
graph.edge_attributes = {
"timestamp": torch.arange(graph.indices.numel()).to(F.ctx())
}
first_items = (first_items, torch.randint(0, 10, (4,)))
first_names = (first_names, "timestamp")
second_items = (second_items, torch.randint(0, 10, (6,)))
second_names = (second_names, "timestamp")
itemset = gb.ItemSetDict( itemset = gb.ItemSetDict(
{ {
"n1:e1:n2": gb.ItemSet( "n1:e1:n2": gb.ItemSet(
torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T, first_items,
names="node_pairs", names=first_names,
), ),
"n2:e2:n1": gb.ItemSet( "n2:e2:n1": gb.ItemSet(
torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T, second_items,
names="node_pairs", names=second_names,
), ),
} }
) )
@@ -209,8 +320,8 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(labor):
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
     datapipe = gb.UniformNegativeSampler(datapipe, graph, 1)
-    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
-    datapipe = Sampler(datapipe, graph, fanouts)
+    sampler = _get_sampler(sampler_type)
+    datapipe = sampler(datapipe, graph, fanouts)
     datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
     assert len(list(datapipe)) == 5
@@ -219,8 +330,11 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(labor):
     F._default_context_str != "cpu",
     reason="Sampling with replacement not yet supported on GPU.",
 )
-@pytest.mark.parametrize("labor", [False, True])
-def test_SubgraphSampler_Random_Hetero_Graph(labor):
+@pytest.mark.parametrize(
+    "sampler_type",
+    [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
+)
+def test_SubgraphSampler_Random_Hetero_Graph(sampler_type):
     num_nodes = 5
     num_edges = 9
     num_ntypes = 3
@@ -235,10 +349,14 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
     ) = gb_test_utils.random_hetero_graph(
         num_nodes, num_edges, num_ntypes, num_etypes
     )
+    node_attributes = {}
     edge_attributes = {
         "A1": torch.randn(num_edges),
         "A2": torch.randn(num_edges),
     }
+    if sampler_type == SamplerType.Temporal:
+        node_attributes["timestamp"] = torch.randint(0, 10, (num_nodes,))
+        edge_attributes["timestamp"] = torch.randint(0, 10, (num_edges,))
     graph = gb.fused_csc_sampling_graph(
         csc_indptr,
         indices,
@@ -246,21 +364,31 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
         type_per_edge=type_per_edge,
         node_type_to_id=node_type_to_id,
         edge_type_to_id=edge_type_to_id,
+        node_attributes=node_attributes,
         edge_attributes=edge_attributes,
     ).to(F.ctx())
+    first_items = torch.tensor([0])
+    first_names = "seed_nodes"
+    second_items = torch.tensor([0])
+    second_names = "seed_nodes"
+    if sampler_type == SamplerType.Temporal:
+        first_items = (first_items, torch.randint(0, 10, (1,)))
+        first_names = (first_names, "timestamp")
+        second_items = (second_items, torch.randint(0, 10, (1,)))
+        second_names = (second_names, "timestamp")
     itemset = gb.ItemSetDict(
         {
-            "n2": gb.ItemSet(torch.tensor([0]), names="seed_nodes"),
-            "n1": gb.ItemSet(torch.tensor([0]), names="seed_nodes"),
+            "n2": gb.ItemSet(first_items, names=first_names),
+            "n1": gb.ItemSet(second_items, names=second_names),
         }
     )
     item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
-    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
-    sampler_dp = Sampler(item_sampler, graph, fanouts, replace=True)
+    sampler = _get_sampler(sampler_type)
+    sampler_dp = sampler(item_sampler, graph, fanouts, replace=True)
     for data in sampler_dp:
         for sampledsubgraph in data.sampled_subgraphs:
@@ -289,23 +417,40 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
     F._default_context_str != "cpu",
     reason="Fails due to randomness on the GPU.",
 )
-@pytest.mark.parametrize("labor", [False, True])
-def test_SubgraphSampler_without_dedpulication_Homo(labor):
+@pytest.mark.parametrize(
+    "sampler_type",
+    [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
+)
+def test_SubgraphSampler_without_dedpulication_Homo(sampler_type):
     graph = dgl.graph(
         ([5, 0, 1, 5, 6, 7, 2, 2, 4], [0, 1, 2, 2, 2, 2, 3, 4, 4])
     )
     graph = gb.from_dglgraph(graph, True).to(F.ctx())
     seed_nodes = torch.LongTensor([0, 3, 4])
+    items = seed_nodes
+    names = "seed_nodes"
+    if sampler_type == SamplerType.Temporal:
+        graph.node_attributes = {
+            "timestamp": torch.zeros(graph.csc_indptr.numel() - 1).to(F.ctx())
+        }
+        graph.edge_attributes = {
+            "timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
+        }
+        items = (items, torch.randint(0, 10, (3,)))
+        names = (names, "timestamp")

-    itemset = gb.ItemSet(seed_nodes, names="seed_nodes")
+    itemset = gb.ItemSet(items, names=names)
     item_sampler = gb.ItemSampler(itemset, batch_size=len(seed_nodes)).copy_to(
         F.ctx()
     )
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
-    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
-    datapipe = Sampler(item_sampler, graph, fanouts, deduplicate=False)
+    sampler = _get_sampler(sampler_type)
+    if sampler_type == SamplerType.Temporal:
+        datapipe = sampler(item_sampler, graph, fanouts)
+    else:
+        datapipe = sampler(item_sampler, graph, fanouts, deduplicate=False)
     length = [17, 7]
     compacted_indices = [
@@ -334,17 +479,32 @@ def test_SubgraphSampler_without_dedpulication_Homo(labor):
     )


-@pytest.mark.parametrize("labor", [False, True])
-def test_SubgraphSampler_without_dedpulication_Hetero(labor):
+@pytest.mark.parametrize(
+    "sampler_type",
+    [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
+)
+def test_SubgraphSampler_without_dedpulication_Hetero(sampler_type):
     graph = get_hetero_graph().to(F.ctx())
-    itemset = gb.ItemSetDict(
-        {"n2": gb.ItemSet(torch.arange(2), names="seed_nodes")}
-    )
+    items = torch.arange(2)
+    names = "seed_nodes"
+    if sampler_type == SamplerType.Temporal:
+        graph.node_attributes = {
+            "timestamp": torch.zeros(graph.csc_indptr.numel() - 1).to(F.ctx())
+        }
+        graph.edge_attributes = {
+            "timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
+        }
+        items = (items, torch.randint(0, 10, (2,)))
+        names = (names, "timestamp")
+    itemset = gb.ItemSetDict({"n2": gb.ItemSet(items, names=names)})
     item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
-    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
-    datapipe = Sampler(item_sampler, graph, fanouts, deduplicate=False)
+    sampler = _get_sampler(sampler_type)
+    if sampler_type == SamplerType.Temporal:
+        datapipe = sampler(item_sampler, graph, fanouts)
+    else:
+        datapipe = sampler(item_sampler, graph, fanouts, deduplicate=False)
     csc_formats = [
         {
             "n1:e1:n2": gb.CSCFormatBase(
......