Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
f5981789
"src/vscode:/vscode.git/clone" did not exist on "f69511ecc618330212e7148265e1c0323d2fa5cf"
Unverified
Commit
f5981789
authored
Jan 03, 2024
by
czkkkkkk
Committed by
GitHub
Jan 03, 2024
Browse files
[Graphbolt] Support temporal sampling in SubgraphSampler. (#6846)
parent
d32a5980
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
350 additions
and
72 deletions
+350
-72
python/dgl/graphbolt/impl/__init__.py
python/dgl/graphbolt/impl/__init__.py
+1
-0
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+2
-2
python/dgl/graphbolt/impl/in_subgraph_sampler.py
python/dgl/graphbolt/impl/in_subgraph_sampler.py
+1
-1
python/dgl/graphbolt/impl/neighbor_sampler.py
python/dgl/graphbolt/impl/neighbor_sampler.py
+1
-1
python/dgl/graphbolt/impl/temporal_neighbor_sampler.py
python/dgl/graphbolt/impl/temporal_neighbor_sampler.py
+6
-3
python/dgl/graphbolt/internal/sample_utils.py
python/dgl/graphbolt/internal/sample_utils.py
+64
-3
python/dgl/graphbolt/subgraph_sampler.py
python/dgl/graphbolt/subgraph_sampler.py
+59
-6
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
+216
-56
No files found.
python/dgl/graphbolt/impl/__init__.py
View file @
f5981789
...
...
@@ -5,6 +5,7 @@ from .gpu_cached_feature import *
from
.in_subgraph_sampler
import
*
from
.legacy_dataset
import
*
from
.neighbor_sampler
import
*
from
.temporal_neighbor_sampler
import
*
from
.ondisk_dataset
import
*
from
.ondisk_metadata
import
*
from
.sampled_subgraph_impl
import
*
...
...
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
View file @
f5981789
...
...
@@ -761,8 +761,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
def
temporal_sample_neighbors
(
self
,
nodes
:
torch
.
Tensor
,
input_nodes_timestamp
:
torch
.
Tensor
,
nodes
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]],
input_nodes_timestamp
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]],
fanouts
:
torch
.
Tensor
,
replace
:
bool
=
False
,
probs_name
:
Optional
[
str
]
=
None
,
...
...
python/dgl/graphbolt/impl/in_subgraph_sampler.py
View file @
f5981789
...
...
@@ -67,7 +67,7 @@ class InSubgraphSampler(SubgraphSampler):
self
.
graph
=
graph
self
.
sampler
=
graph
.
in_subgraph
def
sample_subgraphs
(
self
,
seeds
,
seeds_timestamp
=
None
):
def
sample_subgraphs
(
self
,
seeds
,
seeds_timestamp
):
subgraph
=
self
.
sampler
(
seeds
)
(
original_row_node_ids
,
...
...
python/dgl/graphbolt/impl/neighbor_sampler.py
View file @
f5981789
...
...
@@ -117,7 +117,7 @@ class NeighborSampler(SubgraphSampler):
self
.
deduplicate
=
deduplicate
self
.
sampler
=
graph
.
sample_neighbors
def
sample_subgraphs
(
self
,
seeds
,
seeds_timestamp
=
None
):
def
sample_subgraphs
(
self
,
seeds
,
seeds_timestamp
):
subgraphs
=
[]
num_layers
=
len
(
self
.
fanouts
)
# Enrich seeds with all node types.
...
...
python/dgl/graphbolt/impl/temporal_neighbor_sampler.py
View file @
f5981789
...
...
@@ -89,7 +89,10 @@ class TemporalNeighborSampler(SubgraphSampler):
self
.
edge_timestamp_attr_name
=
edge_timestamp_attr_name
self
.
sampler
=
graph
.
temporal_sample_neighbors
def
sample_subgraphs
(
self
,
seeds
,
seeds_timestamp
=
None
):
def
sample_subgraphs
(
self
,
seeds
,
seeds_timestamp
):
assert
(
seeds_timestamp
is
not
None
),
"seeds_timestamp must be provided for temporal neighbor sampling."
subgraphs
=
[]
num_layers
=
len
(
self
.
fanouts
)
# Enrich seeds with all node types.
...
...
@@ -117,10 +120,10 @@ class TemporalNeighborSampler(SubgraphSampler):
original_row_node_ids
,
compacted_csc_formats
,
row_timestamps
,
)
=
compact_csc_format
(
subgraph
.
node_pairs
,
seeds
,
seeds_timestamp
)
)
=
compact_csc_format
(
subgraph
.
sampled_csc
,
seeds
,
seeds_timestamp
)
subgraph
=
SampledSubgraphImpl
(
node_pairs
=
compacted_csc_formats
,
sampled_csc
=
compacted_csc_formats
,
original_column_node_ids
=
seeds
,
original_row_node_ids
=
original_row_node_ids
,
original_edge_ids
=
subgraph
.
original_edge_ids
,
...
...
python/dgl/graphbolt/internal/sample_utils.py
View file @
f5981789
...
...
@@ -61,6 +61,61 @@ def unique_and_compact(
return
unique_and_compact_per_type
(
nodes
)
def compact_temporal_nodes(nodes, nodes_timestamp):
    """Compact a list of temporal nodes without deduplication.

    Because no unique operation is applied, the nodes and their timestamps
    are simply concatenated in the order given, and the compacted node IDs
    are consecutive integers starting from 0.

    Parameters
    ----------
    nodes : List[torch.Tensor] or Dict[str, List[torch.Tensor]]
        Nodes to compact.
        - If `nodes` is a list of tensors, all tensors are compacted
          together; usually used for homogeneous graphs.
        - If `nodes` is a dictionary, the keys are node types and the values
          are lists of node tensors; compaction is done per type, usually
          used for heterogeneous graphs.
    nodes_timestamp : List[torch.Tensor] or Dict[str, List[torch.Tensor]]
        Timestamps matching `nodes` tensor by tensor (and, for the dictionary
        form, key by key).

    Returns
    -------
    Tuple[nodes, nodes_timestamp, compacted_node_list]
        The concatenated nodes, the concatenated timestamps, and the list of
        compacted node ID tensors, split with the same sizes as the input
        node tensors. For dictionary input, each element is a dictionary
        keyed by node type.

    Raises
    ------
    ValueError
        If the number of timestamp tensors differs from the number of node
        tensors, or if a timestamp tensor does not hold one entry per node.
    """

    def _compact_per_type(per_type_nodes, per_type_nodes_timestamp):
        # Guard against silently misaligned output: concatenation would
        # otherwise pair nodes with timestamps that belong to other nodes.
        if len(per_type_nodes) != len(per_type_nodes_timestamp):
            raise ValueError(
                "The number of timestamp tensors "
                f"({len(per_type_nodes_timestamp)}) does not match the "
                f"number of node tensors ({len(per_type_nodes)})."
            )
        nums = [node.size(0) for node in per_type_nodes]
        timestamp_nums = [ts.numel() for ts in per_type_nodes_timestamp]
        if nums != timestamp_nums:
            raise ValueError(
                f"The sizes of the timestamp tensors {timestamp_nums} do "
                f"not match the sizes of the node tensors {nums}."
            )
        per_type_nodes = torch.cat(per_type_nodes)
        per_type_nodes_timestamp = torch.cat(per_type_nodes_timestamp)
        # Compacted IDs are positions in the concatenated node tensor.
        compacted_nodes = torch.arange(
            0,
            per_type_nodes.numel(),
            dtype=per_type_nodes.dtype,
            device=per_type_nodes.device,
        )
        # Split back so each output lines up with one input node tensor.
        compacted_nodes = list(compacted_nodes.split(nums))
        return per_type_nodes, per_type_nodes_timestamp, compacted_nodes

    if isinstance(nodes, dict):
        ret_nodes, ret_timestamp, compacted = {}, {}, {}
        for ntype, nodes_of_type in nodes.items():
            (
                ret_nodes[ntype],
                ret_timestamp[ntype],
                compacted[ntype],
            ) = _compact_per_type(nodes_of_type, nodes_timestamp[ntype])
        return ret_nodes, ret_timestamp, compacted
    else:
        return _compact_per_type(nodes, nodes_timestamp)
def
unique_and_compact_csc_formats
(
csc_formats
:
Union
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
...
...
@@ -236,7 +291,8 @@ def compact_csc_format(
A tensor of original row node IDs (per type) of all nodes in the input.
The compacted CSC formats, where node IDs are replaced with mapped node
IDs ranging from 0 to N.
The source timestamps (per type) of all nodes in the input if `dst_timestamps` is given.
The source timestamps (per type) of all nodes in the input if
`dst_timestamps` is given.
Examples
--------
...
...
@@ -318,8 +374,13 @@ def compact_csc_format(
src_timestamps
=
None
if
has_timestamp
:
src_timestamps
=
_broadcast_timestamps
(
src_timestamps
=
torch
.
cat
(
[
dst_timestamps
,
_broadcast_timestamps
(
compacted_csc_formats
,
dst_timestamps
),
]
)
else
:
compacted_csc_formats
=
{}
...
...
python/dgl/graphbolt/subgraph_sampler.py
View file @
f5981789
...
...
@@ -6,7 +6,7 @@ from typing import Dict
from
torch.utils.data
import
functional_datapipe
from
.base
import
etype_str_to_tuple
from
.internal
import
unique_and_compact
from
.internal
import
compact_temporal_nodes
,
unique_and_compact
from
.minibatch_transformer
import
MiniBatchTransformer
__all__
=
[
...
...
@@ -40,12 +40,16 @@ class SubgraphSampler(MiniBatchTransformer):
if
minibatch
.
node_pairs
is
not
None
:
(
seeds
,
seeds_timestamp
,
minibatch
.
compacted_node_pairs
,
minibatch
.
compacted_negative_srcs
,
minibatch
.
compacted_negative_dsts
,
)
=
self
.
_node_pairs_preprocess
(
minibatch
)
elif
minibatch
.
seed_nodes
is
not
None
:
seeds
=
minibatch
.
seed_nodes
seeds_timestamp
=
(
minibatch
.
timestamp
if
hasattr
(
minibatch
,
"timestamp"
)
else
None
)
else
:
raise
ValueError
(
f
"Invalid minibatch
{
minibatch
}
: Either `node_pairs` or "
...
...
@@ -54,10 +58,11 @@ class SubgraphSampler(MiniBatchTransformer):
(
minibatch
.
input_nodes
,
minibatch
.
sampled_subgraphs
,
)
=
self
.
sample_subgraphs
(
seeds
)
)
=
self
.
sample_subgraphs
(
seeds
,
seeds_timestamp
)
return
minibatch
def
_node_pairs_preprocess
(
self
,
minibatch
):
use_timestamp
=
hasattr
(
minibatch
,
"timestamp"
)
node_pairs
=
minibatch
.
node_pairs
neg_src
,
neg_dst
=
minibatch
.
negative_srcs
,
minibatch
.
negative_dsts
has_neg_src
=
neg_src
is
not
None
...
...
@@ -72,20 +77,44 @@ class SubgraphSampler(MiniBatchTransformer):
)
# Collect nodes from all types of input.
nodes
=
defaultdict
(
list
)
nodes_timestamp
=
None
if
use_timestamp
:
nodes_timestamp
=
defaultdict
(
list
)
for
etype
,
(
src
,
dst
)
in
node_pairs
.
items
():
src_type
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
nodes
[
src_type
].
append
(
src
)
nodes
[
dst_type
].
append
(
dst
)
if
use_timestamp
:
nodes_timestamp
[
src_type
].
append
(
minibatch
.
timestamp
[
etype
])
nodes_timestamp
[
dst_type
].
append
(
minibatch
.
timestamp
[
etype
])
if
has_neg_src
:
for
etype
,
src
in
neg_src
.
items
():
src_type
,
_
,
_
=
etype_str_to_tuple
(
etype
)
nodes
[
src_type
].
append
(
src
.
view
(
-
1
))
if
use_timestamp
:
nodes_timestamp
[
src_type
].
append
(
minibatch
.
timestamp
[
etype
].
repeat_interleave
(
src
.
shape
[
-
1
]
)
)
if
has_neg_dst
:
for
etype
,
dst
in
neg_dst
.
items
():
_
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
nodes
[
dst_type
].
append
(
dst
.
view
(
-
1
))
if
use_timestamp
:
nodes_timestamp
[
dst_type
].
append
(
minibatch
.
timestamp
[
etype
].
repeat_interleave
(
dst
.
shape
[
-
1
]
)
)
# Unique and compact the collected nodes.
if
use_timestamp
:
seeds
,
nodes_timestamp
,
compacted
=
compact_temporal_nodes
(
nodes
,
nodes_timestamp
)
else
:
seeds
,
compacted
=
unique_and_compact
(
nodes
)
nodes_timestamp
=
None
(
compacted_node_pairs
,
compacted_negative_srcs
,
...
...
@@ -108,12 +137,30 @@ class SubgraphSampler(MiniBatchTransformer):
else
:
# Collect nodes from all types of input.
nodes
=
list
(
node_pairs
)
nodes_timestamp
=
None
if
use_timestamp
:
# Timestamp for source and destination nodes are the same.
nodes_timestamp
=
[
minibatch
.
timestamp
,
minibatch
.
timestamp
]
if
has_neg_src
:
nodes
.
append
(
neg_src
.
view
(
-
1
))
if
use_timestamp
:
nodes_timestamp
.
append
(
minibatch
.
timestamp
.
repeat_interleave
(
neg_src
.
shape
[
-
1
])
)
if
has_neg_dst
:
nodes
.
append
(
neg_dst
.
view
(
-
1
))
if
use_timestamp
:
nodes_timestamp
.
append
(
minibatch
.
timestamp
.
repeat_interleave
(
neg_dst
.
shape
[
-
1
])
)
# Unique and compact the collected nodes.
if
use_timestamp
:
seeds
,
nodes_timestamp
,
compacted
=
compact_temporal_nodes
(
nodes
,
nodes_timestamp
)
else
:
seeds
,
compacted
=
unique_and_compact
(
nodes
)
nodes_timestamp
=
None
# Map back in same order as collect.
compacted_node_pairs
=
tuple
(
compacted
[:
2
])
compacted
=
compacted
[
2
:]
...
...
@@ -132,13 +179,14 @@ class SubgraphSampler(MiniBatchTransformer):
)
return
(
seeds
,
nodes_timestamp
,
compacted_node_pairs
,
compacted_negative_srcs
if
has_neg_src
else
None
,
compacted_negative_dsts
if
has_neg_dst
else
None
,
)
def
sample_subgraphs
(
self
,
seeds
,
seeds_timestamp
=
None
):
"""Sample subgraphs from the given seeds.
def
sample_subgraphs
(
self
,
seeds
,
seeds_timestamp
):
"""Sample subgraphs from the given seeds
, possibly with temporal constraints
.
Any subclass of SubgraphSampler should implement this method.
...
...
@@ -147,6 +195,11 @@ class SubgraphSampler(MiniBatchTransformer):
seeds : Union[torch.Tensor, Dict[str, torch.Tensor]]
The seed nodes.
seeds_timestamp : Union[torch.Tensor, Dict[str, torch.Tensor]]
The timestamps of the seed nodes. If given, the sampled subgraphs
should not contain any nodes or edges that are newer than the
timestamps of the seed nodes. Pass None when no temporal constraint applies.
Returns
-------
Union[torch.Tensor, Dict[str, torch.Tensor]]
...
...
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
View file @
f5981789
import
unittest
from
enum
import
Enum
from
functools
import
partial
import
backend
as
F
...
...
@@ -12,6 +14,31 @@ from torchdata.datapipes.iter import Mapper
from
.
import
gb_test_utils
# Skip all tests on GPU.
pytestmark
=
pytest
.
mark
.
skipif
(
F
.
_default_context_str
!=
"cpu"
,
reason
=
"GraphBolt sampling tests are only supported on CPU."
,
)
class SamplerType(Enum):
    """Kinds of subgraph sampler exercised by the parametrized tests."""

    # Selects gb.NeighborSampler in _get_sampler.
    Normal = 0
    # Selects gb.LayerNeighborSampler in _get_sampler.
    Layer = 1
    # Selects gb.TemporalNeighborSampler in _get_sampler.
    Temporal = 2
def _get_sampler(sampler_type):
    """Return the sampler constructor matching `sampler_type`.

    `SamplerType.Normal` and `SamplerType.Layer` map to the plain and layer
    neighbor sampler classes. Any other value yields the temporal neighbor
    sampler with both timestamp attribute names pre-bound to "timestamp".
    """
    if sampler_type == SamplerType.Normal:
        sampler_factory = gb.NeighborSampler
    elif sampler_type == SamplerType.Layer:
        sampler_factory = gb.LayerNeighborSampler
    else:
        # Pre-bind the attribute names so callers can invoke every sampler
        # with the same (datapipe, graph, fanouts) arguments.
        sampler_factory = partial(
            gb.TemporalNeighborSampler,
            node_timestamp_attr_name="timestamp",
            edge_timestamp_attr_name="timestamp",
        )
    return sampler_factory
def
test_SubgraphSampler_invoke
():
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
10
),
names
=
"seed_nodes"
)
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
).
copy_to
(
F
.
ctx
())
...
...
@@ -76,17 +103,29 @@ def test_NeighborSampler_fanouts(labor):
assert
len
(
list
(
datapipe
))
==
5
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
def
test_SubgraphSampler_Node
(
labor
):
@
pytest
.
mark
.
parametrize
(
"sampler_type"
,
[
SamplerType
.
Normal
,
SamplerType
.
Layer
,
SamplerType
.
Temporal
],
)
def
test_SubgraphSampler_Node
(
sampler_type
):
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
,
bidirection_edge
=
True
).
to
(
F
.
ctx
()
)
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
10
),
names
=
"seed_nodes"
)
items
=
torch
.
arange
(
10
)
names
=
"seed_nodes"
if
sampler_type
==
SamplerType
.
Temporal
:
graph
.
node_attributes
=
{
"timestamp"
:
torch
.
arange
(
20
).
to
(
F
.
ctx
())}
graph
.
edge_attributes
=
{
"timestamp"
:
torch
.
arange
(
len
(
graph
.
indices
)).
to
(
F
.
ctx
())
}
items
=
(
items
,
torch
.
arange
(
10
))
names
=
(
"seed_nodes"
,
"timestamp"
)
itemset
=
gb
.
ItemSet
(
items
,
names
=
names
)
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
).
copy_to
(
F
.
ctx
())
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
S
ampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
sampler_dp
=
S
ampler
(
item_sampler
,
graph
,
fanouts
)
s
ampler
=
_get_sampler
(
sampler_type
)
sampler_dp
=
s
ampler
(
item_sampler
,
graph
,
fanouts
)
assert
len
(
list
(
sampler_dp
))
==
5
...
...
@@ -95,33 +134,57 @@ def to_link_batch(data):
return
block
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
def
test_SubgraphSampler_Link
(
labor
):
@
pytest
.
mark
.
parametrize
(
"sampler_type"
,
[
SamplerType
.
Normal
,
SamplerType
.
Layer
,
SamplerType
.
Temporal
],
)
def
test_SubgraphSampler_Link
(
sampler_type
):
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
,
bidirection_edge
=
True
).
to
(
F
.
ctx
()
)
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
20
).
reshape
(
-
1
,
2
),
names
=
"node_pairs"
)
items
=
torch
.
arange
(
20
).
reshape
(
-
1
,
2
)
names
=
"node_pairs"
if
sampler_type
==
SamplerType
.
Temporal
:
graph
.
node_attributes
=
{
"timestamp"
:
torch
.
arange
(
20
).
to
(
F
.
ctx
())}
graph
.
edge_attributes
=
{
"timestamp"
:
torch
.
arange
(
len
(
graph
.
indices
)).
to
(
F
.
ctx
())
}
items
=
(
items
,
torch
.
arange
(
10
))
names
=
(
"node_pairs"
,
"timestamp"
)
itemset
=
gb
.
ItemSet
(
items
,
names
=
names
)
datapipe
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
).
copy_to
(
F
.
ctx
())
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
S
ampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
datapipe
=
S
ampler
(
datapipe
,
graph
,
fanouts
)
s
ampler
=
_get_sampler
(
sampler_type
)
datapipe
=
s
ampler
(
datapipe
,
graph
,
fanouts
)
datapipe
=
datapipe
.
transform
(
partial
(
gb
.
exclude_seed_edges
))
assert
len
(
list
(
datapipe
))
==
5
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
def
test_SubgraphSampler_Link_With_Negative
(
labor
):
@
pytest
.
mark
.
parametrize
(
"sampler_type"
,
[
SamplerType
.
Normal
,
SamplerType
.
Layer
,
SamplerType
.
Temporal
],
)
def
test_SubgraphSampler_Link_With_Negative
(
sampler_type
):
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
,
bidirection_edge
=
True
).
to
(
F
.
ctx
()
)
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
20
).
reshape
(
-
1
,
2
),
names
=
"node_pairs"
)
items
=
torch
.
arange
(
20
).
reshape
(
-
1
,
2
)
names
=
"node_pairs"
if
sampler_type
==
SamplerType
.
Temporal
:
graph
.
node_attributes
=
{
"timestamp"
:
torch
.
arange
(
20
).
to
(
F
.
ctx
())}
graph
.
edge_attributes
=
{
"timestamp"
:
torch
.
arange
(
len
(
graph
.
indices
)).
to
(
F
.
ctx
())
}
items
=
(
items
,
torch
.
arange
(
10
))
names
=
(
"node_pairs"
,
"timestamp"
)
itemset
=
gb
.
ItemSet
(
items
,
names
=
names
)
datapipe
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
).
copy_to
(
F
.
ctx
())
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
datapipe
=
gb
.
UniformNegativeSampler
(
datapipe
,
graph
,
1
)
S
ampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
datapipe
=
S
ampler
(
datapipe
,
graph
,
fanouts
)
s
ampler
=
_get_sampler
(
sampler_type
)
datapipe
=
s
ampler
(
datapipe
,
graph
,
fanouts
)
datapipe
=
datapipe
.
transform
(
partial
(
gb
.
exclude_seed_edges
))
assert
len
(
list
(
datapipe
))
==
5
...
...
@@ -148,34 +211,64 @@ def get_hetero_graph():
)
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
def
test_SubgraphSampler_Node_Hetero
(
labor
):
@
pytest
.
mark
.
parametrize
(
"sampler_type"
,
[
SamplerType
.
Normal
,
SamplerType
.
Layer
,
SamplerType
.
Temporal
],
)
def
test_SubgraphSampler_Node_Hetero
(
sampler_type
):
graph
=
get_hetero_graph
().
to
(
F
.
ctx
())
itemset
=
gb
.
ItemSetDict
(
{
"n2"
:
gb
.
ItemSet
(
torch
.
arange
(
3
),
names
=
"seed_nodes"
)}
)
items
=
torch
.
arange
(
3
)
names
=
"seed_nodes"
if
sampler_type
==
SamplerType
.
Temporal
:
graph
.
node_attributes
=
{
"timestamp"
:
torch
.
arange
(
graph
.
csc_indptr
.
numel
()
-
1
).
to
(
F
.
ctx
())
}
graph
.
edge_attributes
=
{
"timestamp"
:
torch
.
arange
(
graph
.
indices
.
numel
()).
to
(
F
.
ctx
())
}
items
=
(
items
,
torch
.
randint
(
0
,
10
,
(
3
,)))
names
=
(
names
,
"timestamp"
)
itemset
=
gb
.
ItemSetDict
({
"n2"
:
gb
.
ItemSet
(
items
,
names
=
names
)})
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
).
copy_to
(
F
.
ctx
())
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
S
ampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
sampler_dp
=
S
ampler
(
item_sampler
,
graph
,
fanouts
)
s
ampler
=
_get_sampler
(
sampler_type
)
sampler_dp
=
s
ampler
(
item_sampler
,
graph
,
fanouts
)
assert
len
(
list
(
sampler_dp
))
==
2
for
minibatch
in
sampler_dp
:
assert
len
(
minibatch
.
sampled_subgraphs
)
==
num_layer
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
def
test_SubgraphSampler_Link_Hetero
(
labor
):
@
pytest
.
mark
.
parametrize
(
"sampler_type"
,
[
SamplerType
.
Normal
,
SamplerType
.
Layer
,
SamplerType
.
Temporal
],
)
def
test_SubgraphSampler_Link_Hetero
(
sampler_type
):
graph
=
get_hetero_graph
().
to
(
F
.
ctx
())
first_items
=
torch
.
LongTensor
([[
0
,
0
,
1
,
1
],
[
0
,
2
,
0
,
1
]]).
T
first_names
=
"node_pairs"
second_items
=
torch
.
LongTensor
([[
0
,
0
,
1
,
1
,
2
,
2
],
[
0
,
1
,
1
,
0
,
0
,
1
]]).
T
second_names
=
"node_pairs"
if
sampler_type
==
SamplerType
.
Temporal
:
graph
.
node_attributes
=
{
"timestamp"
:
torch
.
arange
(
graph
.
csc_indptr
.
numel
()
-
1
).
to
(
F
.
ctx
())
}
graph
.
edge_attributes
=
{
"timestamp"
:
torch
.
arange
(
graph
.
indices
.
numel
()).
to
(
F
.
ctx
())
}
first_items
=
(
first_items
,
torch
.
randint
(
0
,
10
,
(
4
,)))
first_names
=
(
first_names
,
"timestamp"
)
second_items
=
(
second_items
,
torch
.
randint
(
0
,
10
,
(
6
,)))
second_names
=
(
second_names
,
"timestamp"
)
itemset
=
gb
.
ItemSetDict
(
{
"n1:e1:n2"
:
gb
.
ItemSet
(
torch
.
LongTensor
([[
0
,
0
,
1
,
1
],
[
0
,
2
,
0
,
1
]]).
T
,
names
=
"node_pairs"
,
first_items
,
names
=
first_names
,
),
"n2:e2:n1"
:
gb
.
ItemSet
(
torch
.
LongTensor
([[
0
,
0
,
1
,
1
,
2
,
2
],
[
0
,
1
,
1
,
0
,
0
,
1
]]).
T
,
names
=
"node_pairs"
,
second_items
,
names
=
second_names
,
),
}
)
...
...
@@ -183,24 +276,42 @@ def test_SubgraphSampler_Link_Hetero(labor):
datapipe
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
).
copy_to
(
F
.
ctx
())
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
S
ampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
datapipe
=
S
ampler
(
datapipe
,
graph
,
fanouts
)
s
ampler
=
_get_sampler
(
sampler_type
)
datapipe
=
s
ampler
(
datapipe
,
graph
,
fanouts
)
datapipe
=
datapipe
.
transform
(
partial
(
gb
.
exclude_seed_edges
))
assert
len
(
list
(
datapipe
))
==
5
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
def
test_SubgraphSampler_Link_Hetero_With_Negative
(
labor
):
@
pytest
.
mark
.
parametrize
(
"sampler_type"
,
[
SamplerType
.
Normal
,
SamplerType
.
Layer
,
SamplerType
.
Temporal
],
)
def
test_SubgraphSampler_Link_Hetero_With_Negative
(
sampler_type
):
graph
=
get_hetero_graph
().
to
(
F
.
ctx
())
first_items
=
torch
.
LongTensor
([[
0
,
0
,
1
,
1
],
[
0
,
2
,
0
,
1
]]).
T
first_names
=
"node_pairs"
second_items
=
torch
.
LongTensor
([[
0
,
0
,
1
,
1
,
2
,
2
],
[
0
,
1
,
1
,
0
,
0
,
1
]]).
T
second_names
=
"node_pairs"
if
sampler_type
==
SamplerType
.
Temporal
:
graph
.
node_attributes
=
{
"timestamp"
:
torch
.
arange
(
graph
.
csc_indptr
.
numel
()
-
1
).
to
(
F
.
ctx
())
}
graph
.
edge_attributes
=
{
"timestamp"
:
torch
.
arange
(
graph
.
indices
.
numel
()).
to
(
F
.
ctx
())
}
first_items
=
(
first_items
,
torch
.
randint
(
0
,
10
,
(
4
,)))
first_names
=
(
first_names
,
"timestamp"
)
second_items
=
(
second_items
,
torch
.
randint
(
0
,
10
,
(
6
,)))
second_names
=
(
second_names
,
"timestamp"
)
itemset
=
gb
.
ItemSetDict
(
{
"n1:e1:n2"
:
gb
.
ItemSet
(
torch
.
LongTensor
([[
0
,
0
,
1
,
1
],
[
0
,
2
,
0
,
1
]]).
T
,
names
=
"node_pairs"
,
first_items
,
names
=
first_names
,
),
"n2:e2:n1"
:
gb
.
ItemSet
(
torch
.
LongTensor
([[
0
,
0
,
1
,
1
,
2
,
2
],
[
0
,
1
,
1
,
0
,
0
,
1
]]).
T
,
names
=
"node_pairs"
,
second_items
,
names
=
second_names
,
),
}
)
...
...
@@ -209,8 +320,8 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(labor):
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
datapipe
=
gb
.
UniformNegativeSampler
(
datapipe
,
graph
,
1
)
S
ampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
datapipe
=
S
ampler
(
datapipe
,
graph
,
fanouts
)
s
ampler
=
_get_sampler
(
sampler_type
)
datapipe
=
s
ampler
(
datapipe
,
graph
,
fanouts
)
datapipe
=
datapipe
.
transform
(
partial
(
gb
.
exclude_seed_edges
))
assert
len
(
list
(
datapipe
))
==
5
...
...
@@ -219,8 +330,11 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(labor):
F
.
_default_context_str
!=
"cpu"
,
reason
=
"Sampling with replacement not yet supported on GPU."
,
)
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
def
test_SubgraphSampler_Random_Hetero_Graph
(
labor
):
@
pytest
.
mark
.
parametrize
(
"sampler_type"
,
[
SamplerType
.
Normal
,
SamplerType
.
Layer
,
SamplerType
.
Temporal
],
)
def
test_SubgraphSampler_Random_Hetero_Graph
(
sampler_type
):
num_nodes
=
5
num_edges
=
9
num_ntypes
=
3
...
...
@@ -235,10 +349,14 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
)
=
gb_test_utils
.
random_hetero_graph
(
num_nodes
,
num_edges
,
num_ntypes
,
num_etypes
)
node_attributes
=
{}
edge_attributes
=
{
"A1"
:
torch
.
randn
(
num_edges
),
"A2"
:
torch
.
randn
(
num_edges
),
}
if
sampler_type
==
SamplerType
.
Temporal
:
node_attributes
[
"timestamp"
]
=
torch
.
randint
(
0
,
10
,
(
num_nodes
,))
edge_attributes
[
"timestamp"
]
=
torch
.
randint
(
0
,
10
,
(
num_edges
,))
graph
=
gb
.
fused_csc_sampling_graph
(
csc_indptr
,
indices
,
...
...
@@ -246,21 +364,31 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
type_per_edge
=
type_per_edge
,
node_type_to_id
=
node_type_to_id
,
edge_type_to_id
=
edge_type_to_id
,
node_attributes
=
node_attributes
,
edge_attributes
=
edge_attributes
,
).
to
(
F
.
ctx
())
first_items
=
torch
.
tensor
([
0
])
first_names
=
"seed_nodes"
second_items
=
torch
.
tensor
([
0
])
second_names
=
"seed_nodes"
if
sampler_type
==
SamplerType
.
Temporal
:
first_items
=
(
first_items
,
torch
.
randint
(
0
,
10
,
(
1
,)))
first_names
=
(
first_names
,
"timestamp"
)
second_items
=
(
second_items
,
torch
.
randint
(
0
,
10
,
(
1
,)))
second_names
=
(
second_names
,
"timestamp"
)
itemset
=
gb
.
ItemSetDict
(
{
"n2"
:
gb
.
ItemSet
(
torch
.
tensor
([
0
]),
names
=
"seed_nod
es
"
),
"n1"
:
gb
.
ItemSet
(
torch
.
tensor
([
0
])
,
names
=
"
se
e
d_n
od
es
"
),
"n2"
:
gb
.
ItemSet
(
first_items
,
names
=
first_nam
es
),
"n1"
:
gb
.
ItemSet
(
second_items
,
names
=
se
con
d_n
am
es
),
}
)
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
).
copy_to
(
F
.
ctx
())
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
S
ampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
s
ampler
=
_get_sampler
(
sampler_type
)
sampler_dp
=
S
ampler
(
item_sampler
,
graph
,
fanouts
,
replace
=
True
)
sampler_dp
=
s
ampler
(
item_sampler
,
graph
,
fanouts
,
replace
=
True
)
for
data
in
sampler_dp
:
for
sampledsubgraph
in
data
.
sampled_subgraphs
:
...
...
@@ -289,23 +417,40 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
F
.
_default_context_str
!=
"cpu"
,
reason
=
"Fails due to randomness on the GPU."
,
)
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
def
test_SubgraphSampler_without_dedpulication_Homo
(
labor
):
@
pytest
.
mark
.
parametrize
(
"sampler_type"
,
[
SamplerType
.
Normal
,
SamplerType
.
Layer
,
SamplerType
.
Temporal
],
)
def
test_SubgraphSampler_without_dedpulication_Homo
(
sampler_type
):
graph
=
dgl
.
graph
(
([
5
,
0
,
1
,
5
,
6
,
7
,
2
,
2
,
4
],
[
0
,
1
,
2
,
2
,
2
,
2
,
3
,
4
,
4
])
)
graph
=
gb
.
from_dglgraph
(
graph
,
True
).
to
(
F
.
ctx
())
seed_nodes
=
torch
.
LongTensor
([
0
,
3
,
4
])
items
=
seed_nodes
names
=
"seed_nodes"
if
sampler_type
==
SamplerType
.
Temporal
:
graph
.
node_attributes
=
{
"timestamp"
:
torch
.
zeros
(
graph
.
csc_indptr
.
numel
()
-
1
).
to
(
F
.
ctx
())
}
graph
.
edge_attributes
=
{
"timestamp"
:
torch
.
zeros
(
graph
.
indices
.
numel
()).
to
(
F
.
ctx
())
}
items
=
(
items
,
torch
.
randint
(
0
,
10
,
(
3
,)))
names
=
(
names
,
"timestamp"
)
itemset
=
gb
.
ItemSet
(
seed_nodes
,
names
=
"seed_nod
es
"
)
itemset
=
gb
.
ItemSet
(
items
,
names
=
nam
es
)
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
len
(
seed_nodes
)).
copy_to
(
F
.
ctx
()
)
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
Sampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
datapipe
=
Sampler
(
item_sampler
,
graph
,
fanouts
,
deduplicate
=
False
)
sampler
=
_get_sampler
(
sampler_type
)
if
sampler_type
==
SamplerType
.
Temporal
:
datapipe
=
sampler
(
item_sampler
,
graph
,
fanouts
)
else
:
datapipe
=
sampler
(
item_sampler
,
graph
,
fanouts
,
deduplicate
=
False
)
length
=
[
17
,
7
]
compacted_indices
=
[
...
...
@@ -334,17 +479,32 @@ def test_SubgraphSampler_without_dedpulication_Homo(labor):
)
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
def
test_SubgraphSampler_without_dedpulication_Hetero
(
labor
):
@
pytest
.
mark
.
parametrize
(
"sampler_type"
,
[
SamplerType
.
Normal
,
SamplerType
.
Layer
,
SamplerType
.
Temporal
],
)
def
test_SubgraphSampler_without_dedpulication_Hetero
(
sampler_type
):
graph
=
get_hetero_graph
().
to
(
F
.
ctx
())
itemset
=
gb
.
ItemSetDict
(
{
"n2"
:
gb
.
ItemSet
(
torch
.
arange
(
2
),
names
=
"seed_nodes"
)}
)
items
=
torch
.
arange
(
2
)
names
=
"seed_nodes"
if
sampler_type
==
SamplerType
.
Temporal
:
graph
.
node_attributes
=
{
"timestamp"
:
torch
.
zeros
(
graph
.
csc_indptr
.
numel
()
-
1
).
to
(
F
.
ctx
())
}
graph
.
edge_attributes
=
{
"timestamp"
:
torch
.
zeros
(
graph
.
indices
.
numel
()).
to
(
F
.
ctx
())
}
items
=
(
items
,
torch
.
randint
(
0
,
10
,
(
2
,)))
names
=
(
names
,
"timestamp"
)
itemset
=
gb
.
ItemSetDict
({
"n2"
:
gb
.
ItemSet
(
items
,
names
=
names
)})
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
).
copy_to
(
F
.
ctx
())
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
Sampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
datapipe
=
Sampler
(
item_sampler
,
graph
,
fanouts
,
deduplicate
=
False
)
sampler
=
_get_sampler
(
sampler_type
)
if
sampler_type
==
SamplerType
.
Temporal
:
datapipe
=
sampler
(
item_sampler
,
graph
,
fanouts
)
else
:
datapipe
=
sampler
(
item_sampler
,
graph
,
fanouts
,
deduplicate
=
False
)
csc_formats
=
[
{
"n1:e1:n2"
:
gb
.
CSCFormatBase
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment