Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
3200b88b
Unverified
Commit
3200b88b
authored
Dec 30, 2023
by
czkkkkkk
Committed by
GitHub
Dec 30, 2023
Browse files
[Graphbolt] Add TemporalNeighborSampler. (#6814)
parent
f758c7c1
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
243 additions
and
31 deletions
+243
-31
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+1
-9
python/dgl/graphbolt/impl/in_subgraph_sampler.py
python/dgl/graphbolt/impl/in_subgraph_sampler.py
+1
-1
python/dgl/graphbolt/impl/neighbor_sampler.py
python/dgl/graphbolt/impl/neighbor_sampler.py
+1
-1
python/dgl/graphbolt/impl/temporal_neighbor_sampler.py
python/dgl/graphbolt/impl/temporal_neighbor_sampler.py
+132
-0
python/dgl/graphbolt/internal/sample_utils.py
python/dgl/graphbolt/internal/sample_utils.py
+105
-17
python/dgl/graphbolt/subgraph_sampler.py
python/dgl/graphbolt/subgraph_sampler.py
+1
-1
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
...n/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
+2
-2
No files found.
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
View file @
3200b88b
...
...
@@ -867,15 +867,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
node_timestamp_attr_name
,
edge_timestamp_attr_name
,
)
# Broadcast the input nodes' timestamp to the sampled neighbors.
sampled_count
=
torch
.
diff
(
C_sampled_subgraph
.
indptr
)
neighbors_timestamp
=
input_nodes_timestamp
.
repeat_interleave
(
sampled_count
)
return
(
self
.
_convert_to_sampled_subgraph
(
C_sampled_subgraph
),
neighbors_timestamp
,
)
return
self
.
_convert_to_sampled_subgraph
(
C_sampled_subgraph
)
def
sample_negative_edges_uniform
(
self
,
edge_type
,
node_pairs
,
negative_ratio
...
...
python/dgl/graphbolt/impl/in_subgraph_sampler.py
View file @
3200b88b
...
...
@@ -67,7 +67,7 @@ class InSubgraphSampler(SubgraphSampler):
self
.
graph
=
graph
self
.
sampler
=
graph
.
in_subgraph
def
sample_subgraphs
(
self
,
seeds
):
def
sample_subgraphs
(
self
,
seeds
,
seeds_timestamp
=
None
):
subgraph
=
self
.
sampler
(
seeds
)
(
original_row_node_ids
,
...
...
python/dgl/graphbolt/impl/neighbor_sampler.py
View file @
3200b88b
...
...
@@ -112,7 +112,7 @@ class NeighborSampler(SubgraphSampler):
self
.
deduplicate
=
deduplicate
self
.
sampler
=
graph
.
sample_neighbors
def
sample_subgraphs
(
self
,
seeds
):
def
sample_subgraphs
(
self
,
seeds
,
seeds_timestamp
=
None
):
subgraphs
=
[]
num_layers
=
len
(
self
.
fanouts
)
# Enrich seeds with all node types.
...
...
python/dgl/graphbolt/impl/temporal_neighbor_sampler.py
0 → 100644
View file @
3200b88b
"""Temporal neighbor subgraph samplers for GraphBolt."""

import torch
from torch.utils.data import functional_datapipe

from ..internal import compact_csc_format
from ..subgraph_sampler import SubgraphSampler
from .sampled_subgraph_impl import SampledSubgraphImpl

__all__ = ["TemporalNeighborSampler"]


@functional_datapipe("temporal_sample_neighbor")
class TemporalNeighborSampler(SubgraphSampler):
    """Temporally sample neighbor edges from a graph and return sampled
    subgraphs.

    Functional name: :obj:`temporal_sample_neighbor`.

    The sampler draws a subgraph from the given data and returns the induced
    subgraph along with compacted information. For node classification, the
    seed nodes are used directly. For link prediction, a pre-processing step
    gathers the unique nodes from the given (positive and negative) node
    pairs and uses them as the seed nodes for the subsequent hops.

    Parameters
    ----------
    datapipe : DataPipe
        The datapipe.
    graph : FusedCSCSamplingGraph
        The graph on which to perform subgraph sampling.
    fanouts: list[torch.Tensor] or list[int]
        The number of edges to be sampled for each node, with or without
        considering edge types. The length of this parameter implicitly
        signifies the number of sampling layers.
        Note: The fanout order is from the outermost layer to the innermost
        layer. For example, the fanout '[15, 10, 5]' means 15 for the
        outermost layer, 10 for the intermediate layer and 5 for the
        innermost layer.
    replace: bool
        Boolean indicating whether the sampling is performed with or
        without replacement. If True, a value can be selected multiple
        times. Otherwise, each value can be selected only once.
    prob_name: str, optional
        The name of an edge attribute used as the weights of sampling for
        each node. This attribute tensor should contain (unnormalized)
        probabilities corresponding to each neighboring edge of a node.
        It must be a 1D floating-point or boolean tensor, with the number
        of elements equalling the total number of edges.
    node_timestamp_attr_name: str, optional
        The name of a node attribute used as the timestamps of nodes.
        It must be a 1D integer tensor, with the number of elements
        equalling the total number of nodes.
    edge_timestamp_attr_name: str, optional
        The name of an edge attribute used as the timestamps of edges.
        It must be a 1D integer tensor, with the number of elements
        equalling the total number of edges.

    Examples
    -------
    TODO(zhenkun) : Add an example after the API to pass timestamps is
    finalized.
    """

    def __init__(
        self,
        datapipe,
        graph,
        fanouts,
        replace=False,
        prob_name=None,
        node_timestamp_attr_name=None,
        edge_timestamp_attr_name=None,
    ):
        super().__init__(datapipe)
        self.graph = graph
        # Normalize every fanout to a tensor and store them innermost-first,
        # reversing the user-facing outermost-first order.
        self.fanouts = [
            fanout
            if isinstance(fanout, torch.Tensor)
            else torch.LongTensor([int(fanout)])
            for fanout in reversed(list(fanouts))
        ]
        self.replace = replace
        self.prob_name = prob_name
        self.node_timestamp_attr_name = node_timestamp_attr_name
        self.edge_timestamp_attr_name = edge_timestamp_attr_name
        self.sampler = graph.temporal_sample_neighbors

    def sample_subgraphs(self, seeds, seeds_timestamp=None):
        """Sample subgraphs layer by layer from ``seeds`` and return the
        final frontier together with the per-layer sampled subgraphs
        (ordered outermost-first)."""
        subgraphs = []
        num_layers = len(self.fanouts)
        # For heterogeneous input, make sure every node type of the graph is
        # present in both the seeds and their timestamps.
        if isinstance(seeds, dict):
            all_types = list(self.graph.node_type_to_id.keys())
            empty = torch.LongTensor([])
            seeds = {t: seeds.get(t, empty) for t in all_types}
            seeds_timestamp = {
                t: seeds_timestamp.get(t, empty) for t in all_types
            }
        for layer in range(num_layers):
            sampled = self.sampler(
                seeds,
                seeds_timestamp,
                self.fanouts[layer],
                self.replace,
                self.prob_name,
                self.node_timestamp_attr_name,
                self.edge_timestamp_attr_name,
            )
            # Relabel row IDs into a contiguous range and broadcast the
            # seed timestamps to the sampled source nodes.
            (
                original_row_node_ids,
                compacted_csc_formats,
                row_timestamps,
            ) = compact_csc_format(
                sampled.node_pairs, seeds, seeds_timestamp
            )
            layer_subgraph = SampledSubgraphImpl(
                node_pairs=compacted_csc_formats,
                original_column_node_ids=seeds,
                original_row_node_ids=original_row_node_ids,
                original_edge_ids=sampled.original_edge_ids,
            )
            # Prepend so the returned list is ordered outermost-first.
            subgraphs.insert(0, layer_subgraph)
            # The sampled rows (and their timestamps) seed the next layer.
            seeds = original_row_node_ids
            seeds_timestamp = row_timestamps
        return seeds, subgraphs
python/dgl/graphbolt/internal/sample_utils.py
View file @
3200b88b
...
...
@@ -2,7 +2,7 @@
import
copy
from
collections
import
defaultdict
from
typing
import
Dict
,
List
,
Tuple
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
import
torch
...
...
@@ -299,12 +299,31 @@ def unique_and_compact_csc_formats(
return
unique_nodes
,
compacted_csc_formats
def _broadcast_timestamps(csc, dst_timestamps):
    """Assign to each source entry the timestamp of its destination node.

    Each destination node's timestamp is repeated once per incoming edge
    (as counted from consecutive differences of ``csc.indptr``), producing
    one timestamp per entry of ``csc.indices``.
    """
    # Per-destination edge counts from the CSC index pointer.
    in_degrees = torch.diff(csc.indptr)
    return dst_timestamps.repeat_interleave(in_degrees)
def
compact_csc_format
(
csc_formats
:
Union
[
CSCFormatBase
,
Dict
[
str
,
CSCFormatBase
]],
dst_nodes
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]],
dst_timestamps
:
Optional
[
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
]
=
None
,
):
"""
Compact csc formats and return original_row_ids (per type).
Relabel the row (source) IDs in the csc formats into a contiguous range from
0 and return the original row node IDs per type.
Note that
1. The column (destination) IDs are included in the relabeled row IDs.
2. If there are repeated row IDs, they would not be uniqued and will be
treated as different nodes.
3. If `dst_timestamps` is given, the timestamp of each destination node will
be broadcasted to its corresponding source nodes.
Parameters
----------
...
...
@@ -323,33 +342,75 @@ def compact_csc_format(
- If `dst_nodes` is a dictionary: The keys are node type and the
values are corresponding nodes. And IDs inside are heterogeneous ids.
dst_timestamps: Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]]
Timestamps of all destination nodes in the csc formats.
If given, the timestamp of each destination node will be broadcasted
to its corresponding source nodes.
Returns
-------
Tuple[original_row_node_ids, compacted_csc_formats]
Tuple[original_row_node_ids, compacted_csc_formats, ...]
A tensor of original row node IDs (per type) of all nodes in the input.
The compacted CSC formats, where node IDs are replaced with mapped node
IDs, and all nodes (per type).
"Compacted CSC formats" indicates that the node IDs in the input node
pairs are replaced with mapped node IDs, where each type of node is
mapped to a contiguous space of IDs ranging from 0 to N.
IDs ranging from 0 to N.
The source timestamps (per type) of all nodes in the input if `dst_timestamps` is given.
Examples
--------
>>> import dgl.graphbolt as gb
>>> N1 = torch.LongTensor([1, 2, 2])
>>> N2 = torch.LongTensor([5, 6, 5])
>>> csc_formats = {"n2:e2:n1": gb.CSCFormatBase(indptr=torch.tensor([0, 1]),
... indices=torch.tensor([5]))}
>>> dst_nodes = {"n1": N1[:1]}
>>> csc_formats = {
... "n2:e2:n1": gb.CSCFormatBase(
... indptr=torch.tensor([0, 1, 3]), indices=torch.tensor([5, 4, 6])
... ),
... "n1:e1:n1": gb.CSCFormatBase(
... indptr=torch.tensor([0, 1, 3]), indices=torch.tensor([1, 2, 3])
... ),
... }
>>> dst_nodes = {"n1": torch.LongTensor([2, 4])}
>>> original_row_node_ids, compacted_csc_formats = gb.compact_csc_format(
... csc_formats, dst_nodes
... )
>>> print(original_row_node_ids)
{'n1': tensor([1]), 'n2': tensor([5])}
>>> print(compacted_csc_formats)
{"n2:e2:n1": CSCFormatBase(indptr=tensor([0, 1]),
... indices=tensor([0]))}
>>> original_row_node_ids
{'n1': tensor([2, 4, 1, 2, 3]), 'n2': tensor([5, 4, 6])}
>>> compacted_csc_formats
{'n2:e2:n1': CSCFormatBase(indptr=tensor([0, 1, 3]),
indices=tensor([0, 1, 2]),
), 'n1:e1:n1': CSCFormatBase(indptr=tensor([0, 1, 3]),
indices=tensor([2, 3, 4]),
)}
>>> csc_formats = {
... "n2:e2:n1": gb.CSCFormatBase(
... indptr=torch.tensor([0, 1, 3]), indices=torch.tensor([5, 4, 6])
... ),
... "n1:e1:n1": gb.CSCFormatBase(
... indptr=torch.tensor([0, 1, 3]), indices=torch.tensor([1, 2, 3])
... ),
... }
>>> dst_nodes = {"n1": torch.LongTensor([2, 4])}
>>> original_row_node_ids, compacted_csc_formats = gb.compact_csc_format(
... csc_formats, dst_nodes
... )
>>> original_row_node_ids
{'n1': tensor([2, 4, 1, 2, 3]), 'n2': tensor([5, 4, 6])}
>>> compacted_csc_formats
{'n2:e2:n1': CSCFormatBase(indptr=tensor([0, 1, 3]),
indices=tensor([0, 1, 2]),
), 'n1:e1:n1': CSCFormatBase(indptr=tensor([0, 1, 3]),
indices=tensor([2, 3, 4]),
)}
>>> dst_timestamps = {"n1": torch.LongTensor([10, 20])}
>>> (
... original_row_node_ids,
... compacted_csc_formats,
... src_timestamps,
... ) = gb.compact_csc_format(csc_formats, dst_nodes, dst_timestamps)
>>> src_timestamps
{'n1': tensor([10, 20, 10, 20, 20]), 'n2': tensor([10, 20, 20])}
"""
is_homogeneous
=
not
isinstance
(
csc_formats
,
dict
)
has_timestamp
=
dst_timestamps
is
not
None
if
is_homogeneous
:
if
dst_nodes
is
not
None
:
assert
isinstance
(
...
...
@@ -371,9 +432,18 @@ def compact_csc_format(
+
offset
),
)
src_timestamps
=
None
if
has_timestamp
:
src_timestamps
=
_broadcast_timestamps
(
compacted_csc_formats
,
dst_timestamps
)
else
:
compacted_csc_formats
=
{}
src_timestamps
=
None
original_row_ids
=
copy
.
deepcopy
(
dst_nodes
)
if
has_timestamp
:
src_timestamps
=
copy
.
deepcopy
(
dst_timestamps
)
for
etype
,
csc_format
in
csc_formats
.
items
():
src_type
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
assert
len
(
dst_nodes
.
get
(
dst_type
,
[]))
+
1
==
len
(
...
...
@@ -406,4 +476,22 @@ def compact_csc_format(
+
offset
),
)
if
has_timestamp
:
# If destination timestamps are given, broadcast them to the
# corresponding source nodes.
src_timestamps
[
src_type
]
=
torch
.
cat
(
(
src_timestamps
.
get
(
src_type
,
torch
.
tensor
(
[],
dtype
=
dst_timestamps
[
dst_type
].
dtype
),
),
_broadcast_timestamps
(
csc_format
,
dst_timestamps
[
dst_type
]
),
)
)
if
has_timestamp
:
return
original_row_ids
,
compacted_csc_formats
,
src_timestamps
return
original_row_ids
,
compacted_csc_formats
python/dgl/graphbolt/subgraph_sampler.py
View file @
3200b88b
...
...
@@ -137,7 +137,7 @@ class SubgraphSampler(MiniBatchTransformer):
compacted_negative_dsts
if
has_neg_dst
else
None
,
)
def
sample_subgraphs
(
self
,
seeds
):
def
sample_subgraphs
(
self
,
seeds
,
seeds_timestamp
=
None
):
"""Sample subgraphs from the given seeds.
Any subclass of SubgraphSampler should implement this method.
...
...
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
View file @
3200b88b
...
...
@@ -911,7 +911,7 @@ def test_temporal_sample_neighbors_homo(
return
available_neighbors
nodes
=
torch
.
tensor
(
seed_list
,
dtype
=
indices_dtype
)
subgraph
,
neighbors_timestamp
=
sampler
(
subgraph
=
sampler
(
nodes
,
seed_timestamp
,
fanouts
,
...
...
@@ -1004,7 +1004,7 @@ def test_temporal_sample_neighbors_hetero(
)
graph
.
edge_attributes
=
{
"timestamp"
:
edge_timestamp
}
subgraph
,
neighbors_timestamp
=
sampler
(
subgraph
=
sampler
(
seeds
,
seed_timestamp
,
fanouts
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment