"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "618072bf5d0da15492003e80af1c913c5f38f76d"
Unverified Commit 3200b88b authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Graphbolt] Add TemporalNeighborSampler. (#6814)

parent f758c7c1
...@@ -867,15 +867,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -867,15 +867,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
node_timestamp_attr_name, node_timestamp_attr_name,
edge_timestamp_attr_name, edge_timestamp_attr_name,
) )
# Broadcast the input nodes' timestamp to the sampled neighbors. return self._convert_to_sampled_subgraph(C_sampled_subgraph)
sampled_count = torch.diff(C_sampled_subgraph.indptr)
neighbors_timestamp = input_nodes_timestamp.repeat_interleave(
sampled_count
)
return (
self._convert_to_sampled_subgraph(C_sampled_subgraph),
neighbors_timestamp,
)
def sample_negative_edges_uniform( def sample_negative_edges_uniform(
self, edge_type, node_pairs, negative_ratio self, edge_type, node_pairs, negative_ratio
......
...@@ -67,7 +67,7 @@ class InSubgraphSampler(SubgraphSampler): ...@@ -67,7 +67,7 @@ class InSubgraphSampler(SubgraphSampler):
self.graph = graph self.graph = graph
self.sampler = graph.in_subgraph self.sampler = graph.in_subgraph
def sample_subgraphs(self, seeds): def sample_subgraphs(self, seeds, seeds_timestamp=None):
subgraph = self.sampler(seeds) subgraph = self.sampler(seeds)
( (
original_row_node_ids, original_row_node_ids,
......
...@@ -112,7 +112,7 @@ class NeighborSampler(SubgraphSampler): ...@@ -112,7 +112,7 @@ class NeighborSampler(SubgraphSampler):
self.deduplicate = deduplicate self.deduplicate = deduplicate
self.sampler = graph.sample_neighbors self.sampler = graph.sample_neighbors
def sample_subgraphs(self, seeds): def sample_subgraphs(self, seeds, seeds_timestamp=None):
subgraphs = [] subgraphs = []
num_layers = len(self.fanouts) num_layers = len(self.fanouts)
# Enrich seeds with all node types. # Enrich seeds with all node types.
......
"""Temporal neighbor subgraph samplers for GraphBolt."""
import torch
from torch.utils.data import functional_datapipe
from ..internal import compact_csc_format
from ..subgraph_sampler import SubgraphSampler
from .sampled_subgraph_impl import SampledSubgraphImpl
__all__ = ["TemporalNeighborSampler"]
@functional_datapipe("temporal_sample_neighbor")
class TemporalNeighborSampler(SubgraphSampler):
"""Temporally sample neighbor edges from a graph and return sampled
subgraphs.
Functional name: :obj:`temporal_sample_neighbor`.
Neighbor sampler is responsible for sampling a subgraph from given data. It
returns an induced subgraph along with compacted information. In the
context of a node classification task, the neighbor sampler directly
utilizes the nodes provided as seed nodes. However, in scenarios involving
link prediction, the process needs another pre-peocess operation. That is,
gathering unique nodes from the given node pairs, encompassing both
positive and negative node pairs, and employs these nodes as the seed nodes
for subsequent steps.
Parameters
----------
datapipe : DataPipe
The datapipe.
graph : FusedCSCSamplingGraph
The graph on which to perform subgraph sampling.
fanouts: list[torch.Tensor] or list[int]
The number of edges to be sampled for each node with or without
considering edge types. The length of this parameter implicitly
signifies the layer of sampling being conducted.
Note: The fanout order is from the outermost layer to innermost layer.
For example, the fanout '[15, 10, 5]' means that 15 to the outermost
layer, 10 to the intermediate layer and 5 corresponds to the innermost
layer.
replace: bool
Boolean indicating whether the sample is preformed with or
without replacement. If True, a value can be selected multiple
times. Otherwise, each value can be selected only once.
prob_name: str, optional
The name of an edge attribute used as the weights of sampling for
each node. This attribute tensor should contain (unnormalized)
probabilities corresponding to each neighboring edge of a node.
It must be a 1D floating-point or boolean tensor, with the number
of elements equalling the total number of edges.
node_timestamp_attr_name: str, optional
The name of an node attribute used as the timestamps of nodes.
It must be a 1D integer tensor, with the number of elements
equalling the total number of nodes.
edge_timestamp_attr_name: str, optional
The name of an edge attribute used as the timestamps of edges.
It must be a 1D integer tensor, with the number of elements
equalling the total number of edges.
Examples
-------
TODO(zhenkun) : Add an example after the API to pass timestamps is finalized.
"""
def __init__(
self,
datapipe,
graph,
fanouts,
replace=False,
prob_name=None,
node_timestamp_attr_name=None,
edge_timestamp_attr_name=None,
):
super().__init__(datapipe)
self.graph = graph
# Convert fanouts to a list of tensors.
self.fanouts = []
for fanout in fanouts:
if not isinstance(fanout, torch.Tensor):
fanout = torch.LongTensor([int(fanout)])
self.fanouts.insert(0, fanout)
self.replace = replace
self.prob_name = prob_name
self.node_timestamp_attr_name = node_timestamp_attr_name
self.edge_timestamp_attr_name = edge_timestamp_attr_name
self.sampler = graph.temporal_sample_neighbors
def sample_subgraphs(self, seeds, seeds_timestamp=None):
subgraphs = []
num_layers = len(self.fanouts)
# Enrich seeds with all node types.
if isinstance(seeds, dict):
ntypes = list(self.graph.node_type_to_id.keys())
seeds = {
ntype: seeds.get(ntype, torch.LongTensor([]))
for ntype in ntypes
}
seeds_timestamp = {
ntype: seeds_timestamp.get(ntype, torch.LongTensor([]))
for ntype in ntypes
}
for hop in range(num_layers):
subgraph = self.sampler(
seeds,
seeds_timestamp,
self.fanouts[hop],
self.replace,
self.prob_name,
self.node_timestamp_attr_name,
self.edge_timestamp_attr_name,
)
(
original_row_node_ids,
compacted_csc_formats,
row_timestamps,
) = compact_csc_format(subgraph.node_pairs, seeds, seeds_timestamp)
subgraph = SampledSubgraphImpl(
node_pairs=compacted_csc_formats,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
)
subgraphs.insert(0, subgraph)
seeds = original_row_node_ids
seeds_timestamp = row_timestamps
return seeds, subgraphs
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import copy import copy
from collections import defaultdict from collections import defaultdict
from typing import Dict, List, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
import torch import torch
...@@ -299,12 +299,31 @@ def unique_and_compact_csc_formats( ...@@ -299,12 +299,31 @@ def unique_and_compact_csc_formats(
return unique_nodes, compacted_csc_formats return unique_nodes, compacted_csc_formats
def _broadcast_timestamps(csc, dst_timestamps):
"""Broadcast the timestamp of each destination node to its corresponding
source nodes."""
count = torch.diff(csc.indptr)
src_timestamps = torch.repeat_interleave(dst_timestamps, count)
return src_timestamps
def compact_csc_format( def compact_csc_format(
csc_formats: Union[CSCFormatBase, Dict[str, CSCFormatBase]], csc_formats: Union[CSCFormatBase, Dict[str, CSCFormatBase]],
dst_nodes: Union[torch.Tensor, Dict[str, torch.Tensor]], dst_nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
dst_timestamps: Optional[
Union[torch.Tensor, Dict[str, torch.Tensor]]
] = None,
): ):
""" """
Compact csc formats and return original_row_ids (per type). Relabel the row (source) IDs in the csc formats into a contiguous range from
0 and return the original row node IDs per type.
Note that
1. The column (destination) IDs are included in the relabeled row IDs.
2. If there are repeated row IDs, they would not be uniqued and will be
treated as different nodes.
3. If `dst_timestamps` is given, the timestamp of each destination node will
be broadcasted to its corresponding source nodes.
Parameters Parameters
---------- ----------
...@@ -323,33 +342,75 @@ def compact_csc_format( ...@@ -323,33 +342,75 @@ def compact_csc_format(
- If `dst_nodes` is a dictionary: The keys are node type and the - If `dst_nodes` is a dictionary: The keys are node type and the
values are corresponding nodes. And IDs inside are heterogeneous ids. values are corresponding nodes. And IDs inside are heterogeneous ids.
dst_timestamps: Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]]
Timestamps of all destination nodes in the csc formats.
If given, the timestamp of each destination node will be broadcasted
to its corresponding source nodes.
Returns Returns
------- -------
Tuple[original_row_node_ids, compacted_csc_formats] Tuple[original_row_node_ids, compacted_csc_formats, ...]
A tensor of original row node IDs (per type) of all nodes in the input.
The compacted CSC formats, where node IDs are replaced with mapped node The compacted CSC formats, where node IDs are replaced with mapped node
IDs, and all nodes (per type). IDs ranging from 0 to N.
"Compacted CSC formats" indicates that the node IDs in the input node The source timestamps (per type) of all nodes in the input if `dst_timestamps` is given.
pairs are replaced with mapped node IDs, where each type of node is
mapped to a contiguous space of IDs ranging from 0 to N.
Examples Examples
-------- --------
>>> import dgl.graphbolt as gb >>> import dgl.graphbolt as gb
>>> N1 = torch.LongTensor([1, 2, 2]) >>> csc_formats = {
>>> N2 = torch.LongTensor([5, 6, 5]) ... "n2:e2:n1": gb.CSCFormatBase(
>>> csc_formats = {"n2:e2:n1": gb.CSCFormatBase(indptr=torch.tensor([0, 1]), ... indptr=torch.tensor([0, 1, 3]), indices=torch.tensor([5, 4, 6])
... indices=torch.tensor([5]))} ... ),
>>> dst_nodes = {"n1": N1[:1]} ... "n1:e1:n1": gb.CSCFormatBase(
... indptr=torch.tensor([0, 1, 3]), indices=torch.tensor([1, 2, 3])
... ),
... }
>>> dst_nodes = {"n1": torch.LongTensor([2, 4])}
>>> original_row_node_ids, compacted_csc_formats = gb.compact_csc_format( >>> original_row_node_ids, compacted_csc_formats = gb.compact_csc_format(
... csc_formats, dst_nodes ... csc_formats, dst_nodes
... ) ... )
>>> print(original_row_node_ids) >>> original_row_node_ids
{'n1': tensor([1]), 'n2': tensor([5])} {'n1': tensor([2, 4, 1, 2, 3]), 'n2': tensor([5, 4, 6])}
>>> print(compacted_csc_formats) >>> compacted_csc_formats
{"n2:e2:n1": CSCFormatBase(indptr=tensor([0, 1]), {'n2:e2:n1': CSCFormatBase(indptr=tensor([0, 1, 3]),
... indices=tensor([0]))} indices=tensor([0, 1, 2]),
), 'n1:e1:n1': CSCFormatBase(indptr=tensor([0, 1, 3]),
indices=tensor([2, 3, 4]),
)}
>>> csc_formats = {
... "n2:e2:n1": gb.CSCFormatBase(
... indptr=torch.tensor([0, 1, 3]), indices=torch.tensor([5, 4, 6])
... ),
... "n1:e1:n1": gb.CSCFormatBase(
... indptr=torch.tensor([0, 1, 3]), indices=torch.tensor([1, 2, 3])
... ),
... }
>>> dst_nodes = {"n1": torch.LongTensor([2, 4])}
>>> original_row_node_ids, compacted_csc_formats = gb.compact_csc_format(
... csc_formats, dst_nodes
... )
>>> original_row_node_ids
{'n1': tensor([2, 4, 1, 2, 3]), 'n2': tensor([5, 4, 6])}
>>> compacted_csc_formats
{'n2:e2:n1': CSCFormatBase(indptr=tensor([0, 1, 3]),
indices=tensor([0, 1, 2]),
), 'n1:e1:n1': CSCFormatBase(indptr=tensor([0, 1, 3]),
indices=tensor([2, 3, 4]),
)}
>>> dst_timestamps = {"n1": torch.LongTensor([10, 20])}
>>> (
... original_row_node_ids,
... compacted_csc_formats,
... src_timestamps,
... ) = gb.compact_csc_format(csc_formats, dst_nodes, dst_timestamps)
>>> src_timestamps
{'n1': tensor([10, 20, 10, 20, 20]), 'n2': tensor([10, 20, 20])}
""" """
is_homogeneous = not isinstance(csc_formats, dict) is_homogeneous = not isinstance(csc_formats, dict)
has_timestamp = dst_timestamps is not None
if is_homogeneous: if is_homogeneous:
if dst_nodes is not None: if dst_nodes is not None:
assert isinstance( assert isinstance(
...@@ -371,9 +432,18 @@ def compact_csc_format( ...@@ -371,9 +432,18 @@ def compact_csc_format(
+ offset + offset
), ),
) )
src_timestamps = None
if has_timestamp:
src_timestamps = _broadcast_timestamps(
compacted_csc_formats, dst_timestamps
)
else: else:
compacted_csc_formats = {} compacted_csc_formats = {}
src_timestamps = None
original_row_ids = copy.deepcopy(dst_nodes) original_row_ids = copy.deepcopy(dst_nodes)
if has_timestamp:
src_timestamps = copy.deepcopy(dst_timestamps)
for etype, csc_format in csc_formats.items(): for etype, csc_format in csc_formats.items():
src_type, _, dst_type = etype_str_to_tuple(etype) src_type, _, dst_type = etype_str_to_tuple(etype)
assert len(dst_nodes.get(dst_type, [])) + 1 == len( assert len(dst_nodes.get(dst_type, [])) + 1 == len(
...@@ -406,4 +476,22 @@ def compact_csc_format( ...@@ -406,4 +476,22 @@ def compact_csc_format(
+ offset + offset
), ),
) )
if has_timestamp:
# If destination timestamps are given, broadcast them to the
# corresponding source nodes.
src_timestamps[src_type] = torch.cat(
(
src_timestamps.get(
src_type,
torch.tensor(
[], dtype=dst_timestamps[dst_type].dtype
),
),
_broadcast_timestamps(
csc_format, dst_timestamps[dst_type]
),
)
)
if has_timestamp:
return original_row_ids, compacted_csc_formats, src_timestamps
return original_row_ids, compacted_csc_formats return original_row_ids, compacted_csc_formats
...@@ -137,7 +137,7 @@ class SubgraphSampler(MiniBatchTransformer): ...@@ -137,7 +137,7 @@ class SubgraphSampler(MiniBatchTransformer):
compacted_negative_dsts if has_neg_dst else None, compacted_negative_dsts if has_neg_dst else None,
) )
def sample_subgraphs(self, seeds): def sample_subgraphs(self, seeds, seeds_timestamp=None):
"""Sample subgraphs from the given seeds. """Sample subgraphs from the given seeds.
Any subclass of SubgraphSampler should implement this method. Any subclass of SubgraphSampler should implement this method.
......
...@@ -911,7 +911,7 @@ def test_temporal_sample_neighbors_homo( ...@@ -911,7 +911,7 @@ def test_temporal_sample_neighbors_homo(
return available_neighbors return available_neighbors
nodes = torch.tensor(seed_list, dtype=indices_dtype) nodes = torch.tensor(seed_list, dtype=indices_dtype)
subgraph, neighbors_timestamp = sampler( subgraph = sampler(
nodes, nodes,
seed_timestamp, seed_timestamp,
fanouts, fanouts,
...@@ -1004,7 +1004,7 @@ def test_temporal_sample_neighbors_hetero( ...@@ -1004,7 +1004,7 @@ def test_temporal_sample_neighbors_hetero(
) )
graph.edge_attributes = {"timestamp": edge_timestamp} graph.edge_attributes = {"timestamp": edge_timestamp}
subgraph, neighbors_timestamp = sampler( subgraph = sampler(
seeds, seeds,
seed_timestamp, seed_timestamp,
fanouts, fanouts,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment