Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
f5981789
"...git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "e3efbc2d9094685dd2d4ae143853941f82f167af"
Unverified
Commit
f5981789
authored
Jan 03, 2024
by
czkkkkkk
Committed by
GitHub
Jan 03, 2024
Browse files
[Graphbolt] Support temporal sampling in SubgraphSampler. (#6846)
parent
d32a5980
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
350 additions
and
72 deletions
+350
-72
python/dgl/graphbolt/impl/__init__.py
python/dgl/graphbolt/impl/__init__.py
+1
-0
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+2
-2
python/dgl/graphbolt/impl/in_subgraph_sampler.py
python/dgl/graphbolt/impl/in_subgraph_sampler.py
+1
-1
python/dgl/graphbolt/impl/neighbor_sampler.py
python/dgl/graphbolt/impl/neighbor_sampler.py
+1
-1
python/dgl/graphbolt/impl/temporal_neighbor_sampler.py
python/dgl/graphbolt/impl/temporal_neighbor_sampler.py
+6
-3
python/dgl/graphbolt/internal/sample_utils.py
python/dgl/graphbolt/internal/sample_utils.py
+64
-3
python/dgl/graphbolt/subgraph_sampler.py
python/dgl/graphbolt/subgraph_sampler.py
+59
-6
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
+216
-56
No files found.
python/dgl/graphbolt/impl/__init__.py
View file @
f5981789
...
@@ -5,6 +5,7 @@ from .gpu_cached_feature import *
...
@@ -5,6 +5,7 @@ from .gpu_cached_feature import *
from
.in_subgraph_sampler
import
*
from
.in_subgraph_sampler
import
*
from
.legacy_dataset
import
*
from
.legacy_dataset
import
*
from
.neighbor_sampler
import
*
from
.neighbor_sampler
import
*
from
.temporal_neighbor_sampler
import
*
from
.ondisk_dataset
import
*
from
.ondisk_dataset
import
*
from
.ondisk_metadata
import
*
from
.ondisk_metadata
import
*
from
.sampled_subgraph_impl
import
*
from
.sampled_subgraph_impl
import
*
...
...
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
View file @
f5981789
...
@@ -761,8 +761,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -761,8 +761,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
def
temporal_sample_neighbors
(
def
temporal_sample_neighbors
(
self
,
self
,
nodes
:
torch
.
Tensor
,
nodes
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]],
input_nodes_timestamp
:
torch
.
Tensor
,
input_nodes_timestamp
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]],
fanouts
:
torch
.
Tensor
,
fanouts
:
torch
.
Tensor
,
replace
:
bool
=
False
,
replace
:
bool
=
False
,
probs_name
:
Optional
[
str
]
=
None
,
probs_name
:
Optional
[
str
]
=
None
,
...
...
python/dgl/graphbolt/impl/in_subgraph_sampler.py
View file @
f5981789
...
@@ -67,7 +67,7 @@ class InSubgraphSampler(SubgraphSampler):
...
@@ -67,7 +67,7 @@ class InSubgraphSampler(SubgraphSampler):
self
.
graph
=
graph
self
.
graph
=
graph
self
.
sampler
=
graph
.
in_subgraph
self
.
sampler
=
graph
.
in_subgraph
def
sample_subgraphs
(
self
,
seeds
,
seeds_timestamp
=
None
):
def
sample_subgraphs
(
self
,
seeds
,
seeds_timestamp
):
subgraph
=
self
.
sampler
(
seeds
)
subgraph
=
self
.
sampler
(
seeds
)
(
(
original_row_node_ids
,
original_row_node_ids
,
...
...
python/dgl/graphbolt/impl/neighbor_sampler.py
View file @
f5981789
...
@@ -117,7 +117,7 @@ class NeighborSampler(SubgraphSampler):
...
@@ -117,7 +117,7 @@ class NeighborSampler(SubgraphSampler):
self
.
deduplicate
=
deduplicate
self
.
deduplicate
=
deduplicate
self
.
sampler
=
graph
.
sample_neighbors
self
.
sampler
=
graph
.
sample_neighbors
def
sample_subgraphs
(
self
,
seeds
,
seeds_timestamp
=
None
):
def
sample_subgraphs
(
self
,
seeds
,
seeds_timestamp
):
subgraphs
=
[]
subgraphs
=
[]
num_layers
=
len
(
self
.
fanouts
)
num_layers
=
len
(
self
.
fanouts
)
# Enrich seeds with all node types.
# Enrich seeds with all node types.
...
...
python/dgl/graphbolt/impl/temporal_neighbor_sampler.py
View file @
f5981789
...
@@ -89,7 +89,10 @@ class TemporalNeighborSampler(SubgraphSampler):
...
@@ -89,7 +89,10 @@ class TemporalNeighborSampler(SubgraphSampler):
self
.
edge_timestamp_attr_name
=
edge_timestamp_attr_name
self
.
edge_timestamp_attr_name
=
edge_timestamp_attr_name
self
.
sampler
=
graph
.
temporal_sample_neighbors
self
.
sampler
=
graph
.
temporal_sample_neighbors
def
sample_subgraphs
(
self
,
seeds
,
seeds_timestamp
=
None
):
def
sample_subgraphs
(
self
,
seeds
,
seeds_timestamp
):
assert
(
seeds_timestamp
is
not
None
),
"seeds_timestamp must be provided for temporal neighbor sampling."
subgraphs
=
[]
subgraphs
=
[]
num_layers
=
len
(
self
.
fanouts
)
num_layers
=
len
(
self
.
fanouts
)
# Enrich seeds with all node types.
# Enrich seeds with all node types.
...
@@ -117,10 +120,10 @@ class TemporalNeighborSampler(SubgraphSampler):
...
@@ -117,10 +120,10 @@ class TemporalNeighborSampler(SubgraphSampler):
original_row_node_ids
,
original_row_node_ids
,
compacted_csc_formats
,
compacted_csc_formats
,
row_timestamps
,
row_timestamps
,
)
=
compact_csc_format
(
subgraph
.
node_pairs
,
seeds
,
seeds_timestamp
)
)
=
compact_csc_format
(
subgraph
.
sampled_csc
,
seeds
,
seeds_timestamp
)
subgraph
=
SampledSubgraphImpl
(
subgraph
=
SampledSubgraphImpl
(
node_pairs
=
compacted_csc_formats
,
sampled_csc
=
compacted_csc_formats
,
original_column_node_ids
=
seeds
,
original_column_node_ids
=
seeds
,
original_row_node_ids
=
original_row_node_ids
,
original_row_node_ids
=
original_row_node_ids
,
original_edge_ids
=
subgraph
.
original_edge_ids
,
original_edge_ids
=
subgraph
.
original_edge_ids
,
...
...
python/dgl/graphbolt/internal/sample_utils.py
View file @
f5981789
...
@@ -61,6 +61,61 @@ def unique_and_compact(
...
@@ -61,6 +61,61 @@ def unique_and_compact(
return
unique_and_compact_per_type
(
nodes
)
return
unique_and_compact_per_type
(
nodes
)
def compact_temporal_nodes(nodes, nodes_timestamp):
    """Compact temporal nodes by concatenation, without deduplication.

    Because no unique operation is applied, the input tensors are simply
    concatenated (per node type) and every position in the concatenated
    tensor receives its own compacted ID. The compacted IDs are therefore the
    consecutive numbers ``0 .. N - 1``, where ``N`` is the total number of
    nodes of that type.

    Parameters
    ----------
    nodes : List[torch.Tensor] or Dict[str, List[torch.Tensor]]
        Nodes to compact.

        - If `nodes` is a list of tensors, all tensors are compacted together.
          This is usually used for homogeneous graphs.
        - If `nodes` is a dictionary, the keys should be node types and the
          values the lists of node tensors of that type. Compaction is done
          independently per type. This is usually used for heterogeneous
          graphs.
    nodes_timestamp : List[torch.Tensor] or Dict[str, List[torch.Tensor]]
        Timestamps of `nodes`, in the same container layout. After
        concatenation there must be exactly one timestamp per node.

    Returns
    -------
    Tuple[nodes, nodes_timestamp, compacted_node_list]
        The concatenated nodes, the concatenated timestamps, and the list of
        compacted node tensors. The i-th compacted tensor has the same length
        as the i-th input tensor, with IDs replaced by compacted node IDs.
        For dictionary input, each of the three is a dictionary keyed by node
        type.

    Raises
    ------
    ValueError
        If the concatenated timestamps do not have the same length as the
        concatenated nodes.
    """

    def _compact_per_type(per_type_nodes, per_type_nodes_timestamp):
        # Remember each input tensor's length so the compacted IDs can be
        # split back into the same pieces afterwards.
        nums = [node.size(0) for node in per_type_nodes]
        per_type_nodes = torch.cat(per_type_nodes)
        per_type_nodes_timestamp = torch.cat(per_type_nodes_timestamp)
        # Without this check a mismatch would go unnoticed here and the
        # caller would receive seeds paired with the wrong timestamps.
        if per_type_nodes_timestamp.shape[0] != per_type_nodes.shape[0]:
            raise ValueError(
                "The number of timestamps "
                f"({per_type_nodes_timestamp.shape[0]}) does not match the "
                f"number of nodes ({per_type_nodes.shape[0]})."
            )
        # No deduplication: position i in the concatenation gets ID i.
        compacted_nodes = torch.arange(
            0,
            per_type_nodes.numel(),
            dtype=per_type_nodes.dtype,
            device=per_type_nodes.device,
        )
        compacted_nodes = list(compacted_nodes.split(nums))
        return per_type_nodes, per_type_nodes_timestamp, compacted_nodes

    if not isinstance(nodes, dict):
        return _compact_per_type(nodes, nodes_timestamp)

    ret_nodes, ret_timestamp, compacted = {}, {}, {}
    for ntype, nodes_of_type in nodes.items():
        (
            ret_nodes[ntype],
            ret_timestamp[ntype],
            compacted[ntype],
        ) = _compact_per_type(nodes_of_type, nodes_timestamp[ntype])
    return ret_nodes, ret_timestamp, compacted
def
unique_and_compact_csc_formats
(
def
unique_and_compact_csc_formats
(
csc_formats
:
Union
[
csc_formats
:
Union
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
...
@@ -236,7 +291,8 @@ def compact_csc_format(
...
@@ -236,7 +291,8 @@ def compact_csc_format(
A tensor of original row node IDs (per type) of all nodes in the input.
A tensor of original row node IDs (per type) of all nodes in the input.
The compacted CSC formats, where node IDs are replaced with mapped node
The compacted CSC formats, where node IDs are replaced with mapped node
IDs ranging from 0 to N.
IDs ranging from 0 to N.
The source timestamps (per type) of all nodes in the input if `dst_timestamps` is given.
The source timestamps (per type) of all nodes in the input if
`dst_timestamps` is given.
Examples
Examples
--------
--------
...
@@ -318,8 +374,13 @@ def compact_csc_format(
...
@@ -318,8 +374,13 @@ def compact_csc_format(
src_timestamps
=
None
src_timestamps
=
None
if
has_timestamp
:
if
has_timestamp
:
src_timestamps
=
_broadcast_timestamps
(
src_timestamps
=
torch
.
cat
(
compacted_csc_formats
,
dst_timestamps
[
dst_timestamps
,
_broadcast_timestamps
(
compacted_csc_formats
,
dst_timestamps
),
]
)
)
else
:
else
:
compacted_csc_formats
=
{}
compacted_csc_formats
=
{}
...
...
python/dgl/graphbolt/subgraph_sampler.py
View file @
f5981789
...
@@ -6,7 +6,7 @@ from typing import Dict
...
@@ -6,7 +6,7 @@ from typing import Dict
from
torch.utils.data
import
functional_datapipe
from
torch.utils.data
import
functional_datapipe
from
.base
import
etype_str_to_tuple
from
.base
import
etype_str_to_tuple
from
.internal
import
unique_and_compact
from
.internal
import
compact_temporal_nodes
,
unique_and_compact
from
.minibatch_transformer
import
MiniBatchTransformer
from
.minibatch_transformer
import
MiniBatchTransformer
__all__
=
[
__all__
=
[
...
@@ -40,12 +40,16 @@ class SubgraphSampler(MiniBatchTransformer):
...
@@ -40,12 +40,16 @@ class SubgraphSampler(MiniBatchTransformer):
if
minibatch
.
node_pairs
is
not
None
:
if
minibatch
.
node_pairs
is
not
None
:
(
(
seeds
,
seeds
,
seeds_timestamp
,
minibatch
.
compacted_node_pairs
,
minibatch
.
compacted_node_pairs
,
minibatch
.
compacted_negative_srcs
,
minibatch
.
compacted_negative_srcs
,
minibatch
.
compacted_negative_dsts
,
minibatch
.
compacted_negative_dsts
,
)
=
self
.
_node_pairs_preprocess
(
minibatch
)
)
=
self
.
_node_pairs_preprocess
(
minibatch
)
elif
minibatch
.
seed_nodes
is
not
None
:
elif
minibatch
.
seed_nodes
is
not
None
:
seeds
=
minibatch
.
seed_nodes
seeds
=
minibatch
.
seed_nodes
seeds_timestamp
=
(
minibatch
.
timestamp
if
hasattr
(
minibatch
,
"timestamp"
)
else
None
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
f
"Invalid minibatch
{
minibatch
}
: Either `node_pairs` or "
f
"Invalid minibatch
{
minibatch
}
: Either `node_pairs` or "
...
@@ -54,10 +58,11 @@ class SubgraphSampler(MiniBatchTransformer):
...
@@ -54,10 +58,11 @@ class SubgraphSampler(MiniBatchTransformer):
(
(
minibatch
.
input_nodes
,
minibatch
.
input_nodes
,
minibatch
.
sampled_subgraphs
,
minibatch
.
sampled_subgraphs
,
)
=
self
.
sample_subgraphs
(
seeds
)
)
=
self
.
sample_subgraphs
(
seeds
,
seeds_timestamp
)
return
minibatch
return
minibatch
def
_node_pairs_preprocess
(
self
,
minibatch
):
def
_node_pairs_preprocess
(
self
,
minibatch
):
use_timestamp
=
hasattr
(
minibatch
,
"timestamp"
)
node_pairs
=
minibatch
.
node_pairs
node_pairs
=
minibatch
.
node_pairs
neg_src
,
neg_dst
=
minibatch
.
negative_srcs
,
minibatch
.
negative_dsts
neg_src
,
neg_dst
=
minibatch
.
negative_srcs
,
minibatch
.
negative_dsts
has_neg_src
=
neg_src
is
not
None
has_neg_src
=
neg_src
is
not
None
...
@@ -72,20 +77,44 @@ class SubgraphSampler(MiniBatchTransformer):
...
@@ -72,20 +77,44 @@ class SubgraphSampler(MiniBatchTransformer):
)
)
# Collect nodes from all types of input.
# Collect nodes from all types of input.
nodes
=
defaultdict
(
list
)
nodes
=
defaultdict
(
list
)
nodes_timestamp
=
None
if
use_timestamp
:
nodes_timestamp
=
defaultdict
(
list
)
for
etype
,
(
src
,
dst
)
in
node_pairs
.
items
():
for
etype
,
(
src
,
dst
)
in
node_pairs
.
items
():
src_type
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
src_type
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
nodes
[
src_type
].
append
(
src
)
nodes
[
src_type
].
append
(
src
)
nodes
[
dst_type
].
append
(
dst
)
nodes
[
dst_type
].
append
(
dst
)
if
use_timestamp
:
nodes_timestamp
[
src_type
].
append
(
minibatch
.
timestamp
[
etype
])
nodes_timestamp
[
dst_type
].
append
(
minibatch
.
timestamp
[
etype
])
if
has_neg_src
:
if
has_neg_src
:
for
etype
,
src
in
neg_src
.
items
():
for
etype
,
src
in
neg_src
.
items
():
src_type
,
_
,
_
=
etype_str_to_tuple
(
etype
)
src_type
,
_
,
_
=
etype_str_to_tuple
(
etype
)
nodes
[
src_type
].
append
(
src
.
view
(
-
1
))
nodes
[
src_type
].
append
(
src
.
view
(
-
1
))
if
use_timestamp
:
nodes_timestamp
[
src_type
].
append
(
minibatch
.
timestamp
[
etype
].
repeat_interleave
(
src
.
shape
[
-
1
]
)
)
if
has_neg_dst
:
if
has_neg_dst
:
for
etype
,
dst
in
neg_dst
.
items
():
for
etype
,
dst
in
neg_dst
.
items
():
_
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
_
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
nodes
[
dst_type
].
append
(
dst
.
view
(
-
1
))
nodes
[
dst_type
].
append
(
dst
.
view
(
-
1
))
if
use_timestamp
:
nodes_timestamp
[
dst_type
].
append
(
minibatch
.
timestamp
[
etype
].
repeat_interleave
(
dst
.
shape
[
-
1
]
)
)
# Unique and compact the collected nodes.
# Unique and compact the collected nodes.
seeds
,
compacted
=
unique_and_compact
(
nodes
)
if
use_timestamp
:
seeds
,
nodes_timestamp
,
compacted
=
compact_temporal_nodes
(
nodes
,
nodes_timestamp
)
else
:
seeds
,
compacted
=
unique_and_compact
(
nodes
)
nodes_timestamp
=
None
(
(
compacted_node_pairs
,
compacted_node_pairs
,
compacted_negative_srcs
,
compacted_negative_srcs
,
...
@@ -108,12 +137,30 @@ class SubgraphSampler(MiniBatchTransformer):
...
@@ -108,12 +137,30 @@ class SubgraphSampler(MiniBatchTransformer):
else
:
else
:
# Collect nodes from all types of input.
# Collect nodes from all types of input.
nodes
=
list
(
node_pairs
)
nodes
=
list
(
node_pairs
)
nodes_timestamp
=
None
if
use_timestamp
:
# Timestamp for source and destination nodes are the same.
nodes_timestamp
=
[
minibatch
.
timestamp
,
minibatch
.
timestamp
]
if
has_neg_src
:
if
has_neg_src
:
nodes
.
append
(
neg_src
.
view
(
-
1
))
nodes
.
append
(
neg_src
.
view
(
-
1
))
if
use_timestamp
:
nodes_timestamp
.
append
(
minibatch
.
timestamp
.
repeat_interleave
(
neg_src
.
shape
[
-
1
])
)
if
has_neg_dst
:
if
has_neg_dst
:
nodes
.
append
(
neg_dst
.
view
(
-
1
))
nodes
.
append
(
neg_dst
.
view
(
-
1
))
if
use_timestamp
:
nodes_timestamp
.
append
(
minibatch
.
timestamp
.
repeat_interleave
(
neg_dst
.
shape
[
-
1
])
)
# Unique and compact the collected nodes.
# Unique and compact the collected nodes.
seeds
,
compacted
=
unique_and_compact
(
nodes
)
if
use_timestamp
:
seeds
,
nodes_timestamp
,
compacted
=
compact_temporal_nodes
(
nodes
,
nodes_timestamp
)
else
:
seeds
,
compacted
=
unique_and_compact
(
nodes
)
nodes_timestamp
=
None
# Map back in same order as collect.
# Map back in same order as collect.
compacted_node_pairs
=
tuple
(
compacted
[:
2
])
compacted_node_pairs
=
tuple
(
compacted
[:
2
])
compacted
=
compacted
[
2
:]
compacted
=
compacted
[
2
:]
...
@@ -132,13 +179,14 @@ class SubgraphSampler(MiniBatchTransformer):
...
@@ -132,13 +179,14 @@ class SubgraphSampler(MiniBatchTransformer):
)
)
return
(
return
(
seeds
,
seeds
,
nodes_timestamp
,
compacted_node_pairs
,
compacted_node_pairs
,
compacted_negative_srcs
if
has_neg_src
else
None
,
compacted_negative_srcs
if
has_neg_src
else
None
,
compacted_negative_dsts
if
has_neg_dst
else
None
,
compacted_negative_dsts
if
has_neg_dst
else
None
,
)
)
def
sample_subgraphs
(
self
,
seeds
,
seeds_timestamp
=
None
):
def
sample_subgraphs
(
self
,
seeds
,
seeds_timestamp
):
"""Sample subgraphs from the given seeds.
"""Sample subgraphs from the given seeds
, possibly with temporal constraints
.
Any subclass of SubgraphSampler should implement this method.
Any subclass of SubgraphSampler should implement this method.
...
@@ -147,6 +195,11 @@ class SubgraphSampler(MiniBatchTransformer):
...
@@ -147,6 +195,11 @@ class SubgraphSampler(MiniBatchTransformer):
seeds : Union[torch.Tensor, Dict[str, torch.Tensor]]
seeds : Union[torch.Tensor, Dict[str, torch.Tensor]]
The seed nodes.
The seed nodes.
seeds_timestamp : Union[torch.Tensor, Dict[str, torch.Tensor]]
The timestamps of the seed nodes. If given, the sampled subgraphs
should not contain any nodes or edges that are newer than the
timestamps of the seed nodes. Default: None.
Returns
Returns
-------
-------
Union[torch.Tensor, Dict[str, torch.Tensor]]
Union[torch.Tensor, Dict[str, torch.Tensor]]
...
...
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
View file @
f5981789
import
unittest
import
unittest
from
enum
import
Enum
from
functools
import
partial
from
functools
import
partial
import
backend
as
F
import
backend
as
F
...
@@ -12,6 +14,31 @@ from torchdata.datapipes.iter import Mapper
...
@@ -12,6 +14,31 @@ from torchdata.datapipes.iter import Mapper
from
.
import
gb_test_utils
from
.
import
gb_test_utils
# Module-level pytest marker: skip every test in this file whenever the
# default backend context is anything other than "cpu" (i.e. on GPU runs).
pytestmark = pytest.mark.skipif(
    F._default_context_str != "cpu",
    reason="GraphBolt sampling tests are only supported on CPU.",
)
class SamplerType(Enum):
    """Sampler variants the tests in this module are parametrized over.

    `_get_sampler` maps each member to the GraphBolt sampler under test.
    """

    # Values follow declaration order: Normal = 0, Layer = 1, Temporal = 2.
    Normal, Layer, Temporal = range(3)
def _get_sampler(sampler_type):
    """Return the sampler constructor that corresponds to ``sampler_type``.

    ``SamplerType.Normal`` maps to ``gb.NeighborSampler`` and
    ``SamplerType.Layer`` maps to ``gb.LayerNeighborSampler``. Any other value
    maps to ``gb.TemporalNeighborSampler`` with both timestamp attribute names
    pre-bound to ``"timestamp"``, so all three results can be called with the
    same positional arguments.
    """
    if sampler_type == SamplerType.Normal:
        sampler_factory = gb.NeighborSampler
    elif sampler_type == SamplerType.Layer:
        sampler_factory = gb.LayerNeighborSampler
    else:
        # Fallback branch (SamplerType.Temporal among the declared members):
        # pre-bind the attribute names the temporal sampler is given here.
        timestamp_kwargs = {
            "node_timestamp_attr_name": "timestamp",
            "edge_timestamp_attr_name": "timestamp",
        }
        sampler_factory = partial(
            gb.TemporalNeighborSampler, **timestamp_kwargs
        )
    return sampler_factory
def
test_SubgraphSampler_invoke
():
def
test_SubgraphSampler_invoke
():
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
10
),
names
=
"seed_nodes"
)
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
10
),
names
=
"seed_nodes"
)
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
).
copy_to
(
F
.
ctx
())
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
).
copy_to
(
F
.
ctx
())
...
@@ -76,17 +103,29 @@ def test_NeighborSampler_fanouts(labor):
...
@@ -76,17 +103,29 @@ def test_NeighborSampler_fanouts(labor):
assert
len
(
list
(
datapipe
))
==
5
assert
len
(
list
(
datapipe
))
==
5
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
def
test_SubgraphSampler_Node
(
labor
):
"sampler_type"
,
[
SamplerType
.
Normal
,
SamplerType
.
Layer
,
SamplerType
.
Temporal
],
)
def
test_SubgraphSampler_Node
(
sampler_type
):
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
,
bidirection_edge
=
True
).
to
(
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
,
bidirection_edge
=
True
).
to
(
F
.
ctx
()
F
.
ctx
()
)
)
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
10
),
names
=
"seed_nodes"
)
items
=
torch
.
arange
(
10
)
names
=
"seed_nodes"
if
sampler_type
==
SamplerType
.
Temporal
:
graph
.
node_attributes
=
{
"timestamp"
:
torch
.
arange
(
20
).
to
(
F
.
ctx
())}
graph
.
edge_attributes
=
{
"timestamp"
:
torch
.
arange
(
len
(
graph
.
indices
)).
to
(
F
.
ctx
())
}
items
=
(
items
,
torch
.
arange
(
10
))
names
=
(
"seed_nodes"
,
"timestamp"
)
itemset
=
gb
.
ItemSet
(
items
,
names
=
names
)
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
).
copy_to
(
F
.
ctx
())
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
).
copy_to
(
F
.
ctx
())
num_layer
=
2
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
S
ampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
s
ampler
=
_get_sampler
(
sampler_type
)
sampler_dp
=
S
ampler
(
item_sampler
,
graph
,
fanouts
)
sampler_dp
=
s
ampler
(
item_sampler
,
graph
,
fanouts
)
assert
len
(
list
(
sampler_dp
))
==
5
assert
len
(
list
(
sampler_dp
))
==
5
...
@@ -95,33 +134,57 @@ def to_link_batch(data):
...
@@ -95,33 +134,57 @@ def to_link_batch(data):
return
block
return
block
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
def
test_SubgraphSampler_Link
(
labor
):
"sampler_type"
,
[
SamplerType
.
Normal
,
SamplerType
.
Layer
,
SamplerType
.
Temporal
],
)
def
test_SubgraphSampler_Link
(
sampler_type
):
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
,
bidirection_edge
=
True
).
to
(
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
,
bidirection_edge
=
True
).
to
(
F
.
ctx
()
F
.
ctx
()
)
)
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
20
).
reshape
(
-
1
,
2
),
names
=
"node_pairs"
)
items
=
torch
.
arange
(
20
).
reshape
(
-
1
,
2
)
names
=
"node_pairs"
if
sampler_type
==
SamplerType
.
Temporal
:
graph
.
node_attributes
=
{
"timestamp"
:
torch
.
arange
(
20
).
to
(
F
.
ctx
())}
graph
.
edge_attributes
=
{
"timestamp"
:
torch
.
arange
(
len
(
graph
.
indices
)).
to
(
F
.
ctx
())
}
items
=
(
items
,
torch
.
arange
(
10
))
names
=
(
"node_pairs"
,
"timestamp"
)
itemset
=
gb
.
ItemSet
(
items
,
names
=
names
)
datapipe
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
).
copy_to
(
F
.
ctx
())
datapipe
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
).
copy_to
(
F
.
ctx
())
num_layer
=
2
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
S
ampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
s
ampler
=
_get_sampler
(
sampler_type
)
datapipe
=
S
ampler
(
datapipe
,
graph
,
fanouts
)
datapipe
=
s
ampler
(
datapipe
,
graph
,
fanouts
)
datapipe
=
datapipe
.
transform
(
partial
(
gb
.
exclude_seed_edges
))
datapipe
=
datapipe
.
transform
(
partial
(
gb
.
exclude_seed_edges
))
assert
len
(
list
(
datapipe
))
==
5
assert
len
(
list
(
datapipe
))
==
5
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
def
test_SubgraphSampler_Link_With_Negative
(
labor
):
"sampler_type"
,
[
SamplerType
.
Normal
,
SamplerType
.
Layer
,
SamplerType
.
Temporal
],
)
def
test_SubgraphSampler_Link_With_Negative
(
sampler_type
):
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
,
bidirection_edge
=
True
).
to
(
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
,
bidirection_edge
=
True
).
to
(
F
.
ctx
()
F
.
ctx
()
)
)
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
20
).
reshape
(
-
1
,
2
),
names
=
"node_pairs"
)
items
=
torch
.
arange
(
20
).
reshape
(
-
1
,
2
)
names
=
"node_pairs"
if
sampler_type
==
SamplerType
.
Temporal
:
graph
.
node_attributes
=
{
"timestamp"
:
torch
.
arange
(
20
).
to
(
F
.
ctx
())}
graph
.
edge_attributes
=
{
"timestamp"
:
torch
.
arange
(
len
(
graph
.
indices
)).
to
(
F
.
ctx
())
}
items
=
(
items
,
torch
.
arange
(
10
))
names
=
(
"node_pairs"
,
"timestamp"
)
itemset
=
gb
.
ItemSet
(
items
,
names
=
names
)
datapipe
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
).
copy_to
(
F
.
ctx
())
datapipe
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
).
copy_to
(
F
.
ctx
())
num_layer
=
2
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
datapipe
=
gb
.
UniformNegativeSampler
(
datapipe
,
graph
,
1
)
datapipe
=
gb
.
UniformNegativeSampler
(
datapipe
,
graph
,
1
)
S
ampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
s
ampler
=
_get_sampler
(
sampler_type
)
datapipe
=
S
ampler
(
datapipe
,
graph
,
fanouts
)
datapipe
=
s
ampler
(
datapipe
,
graph
,
fanouts
)
datapipe
=
datapipe
.
transform
(
partial
(
gb
.
exclude_seed_edges
))
datapipe
=
datapipe
.
transform
(
partial
(
gb
.
exclude_seed_edges
))
assert
len
(
list
(
datapipe
))
==
5
assert
len
(
list
(
datapipe
))
==
5
...
@@ -148,34 +211,64 @@ def get_hetero_graph():
...
@@ -148,34 +211,64 @@ def get_hetero_graph():
)
)
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
def
test_SubgraphSampler_Node_Hetero
(
labor
):
"sampler_type"
,
[
SamplerType
.
Normal
,
SamplerType
.
Layer
,
SamplerType
.
Temporal
],
)
def
test_SubgraphSampler_Node_Hetero
(
sampler_type
):
graph
=
get_hetero_graph
().
to
(
F
.
ctx
())
graph
=
get_hetero_graph
().
to
(
F
.
ctx
())
itemset
=
gb
.
ItemSetDict
(
items
=
torch
.
arange
(
3
)
{
"n2"
:
gb
.
ItemSet
(
torch
.
arange
(
3
),
names
=
"seed_nodes"
)}
names
=
"seed_nodes"
)
if
sampler_type
==
SamplerType
.
Temporal
:
graph
.
node_attributes
=
{
"timestamp"
:
torch
.
arange
(
graph
.
csc_indptr
.
numel
()
-
1
).
to
(
F
.
ctx
())
}
graph
.
edge_attributes
=
{
"timestamp"
:
torch
.
arange
(
graph
.
indices
.
numel
()).
to
(
F
.
ctx
())
}
items
=
(
items
,
torch
.
randint
(
0
,
10
,
(
3
,)))
names
=
(
names
,
"timestamp"
)
itemset
=
gb
.
ItemSetDict
({
"n2"
:
gb
.
ItemSet
(
items
,
names
=
names
)})
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
).
copy_to
(
F
.
ctx
())
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
).
copy_to
(
F
.
ctx
())
num_layer
=
2
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
S
ampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
s
ampler
=
_get_sampler
(
sampler_type
)
sampler_dp
=
S
ampler
(
item_sampler
,
graph
,
fanouts
)
sampler_dp
=
s
ampler
(
item_sampler
,
graph
,
fanouts
)
assert
len
(
list
(
sampler_dp
))
==
2
assert
len
(
list
(
sampler_dp
))
==
2
for
minibatch
in
sampler_dp
:
for
minibatch
in
sampler_dp
:
assert
len
(
minibatch
.
sampled_subgraphs
)
==
num_layer
assert
len
(
minibatch
.
sampled_subgraphs
)
==
num_layer
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
def
test_SubgraphSampler_Link_Hetero
(
labor
):
"sampler_type"
,
[
SamplerType
.
Normal
,
SamplerType
.
Layer
,
SamplerType
.
Temporal
],
)
def
test_SubgraphSampler_Link_Hetero
(
sampler_type
):
graph
=
get_hetero_graph
().
to
(
F
.
ctx
())
graph
=
get_hetero_graph
().
to
(
F
.
ctx
())
first_items
=
torch
.
LongTensor
([[
0
,
0
,
1
,
1
],
[
0
,
2
,
0
,
1
]]).
T
first_names
=
"node_pairs"
second_items
=
torch
.
LongTensor
([[
0
,
0
,
1
,
1
,
2
,
2
],
[
0
,
1
,
1
,
0
,
0
,
1
]]).
T
second_names
=
"node_pairs"
if
sampler_type
==
SamplerType
.
Temporal
:
graph
.
node_attributes
=
{
"timestamp"
:
torch
.
arange
(
graph
.
csc_indptr
.
numel
()
-
1
).
to
(
F
.
ctx
())
}
graph
.
edge_attributes
=
{
"timestamp"
:
torch
.
arange
(
graph
.
indices
.
numel
()).
to
(
F
.
ctx
())
}
first_items
=
(
first_items
,
torch
.
randint
(
0
,
10
,
(
4
,)))
first_names
=
(
first_names
,
"timestamp"
)
second_items
=
(
second_items
,
torch
.
randint
(
0
,
10
,
(
6
,)))
second_names
=
(
second_names
,
"timestamp"
)
itemset
=
gb
.
ItemSetDict
(
itemset
=
gb
.
ItemSetDict
(
{
{
"n1:e1:n2"
:
gb
.
ItemSet
(
"n1:e1:n2"
:
gb
.
ItemSet
(
torch
.
LongTensor
([[
0
,
0
,
1
,
1
],
[
0
,
2
,
0
,
1
]]).
T
,
first_items
,
names
=
"node_pairs"
,
names
=
first_names
,
),
),
"n2:e2:n1"
:
gb
.
ItemSet
(
"n2:e2:n1"
:
gb
.
ItemSet
(
torch
.
LongTensor
([[
0
,
0
,
1
,
1
,
2
,
2
],
[
0
,
1
,
1
,
0
,
0
,
1
]]).
T
,
second_items
,
names
=
"node_pairs"
,
names
=
second_names
,
),
),
}
}
)
)
...
@@ -183,24 +276,42 @@ def test_SubgraphSampler_Link_Hetero(labor):
...
@@ -183,24 +276,42 @@ def test_SubgraphSampler_Link_Hetero(labor):
datapipe
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
).
copy_to
(
F
.
ctx
())
datapipe
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
).
copy_to
(
F
.
ctx
())
num_layer
=
2
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
S
ampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
s
ampler
=
_get_sampler
(
sampler_type
)
datapipe
=
S
ampler
(
datapipe
,
graph
,
fanouts
)
datapipe
=
s
ampler
(
datapipe
,
graph
,
fanouts
)
datapipe
=
datapipe
.
transform
(
partial
(
gb
.
exclude_seed_edges
))
datapipe
=
datapipe
.
transform
(
partial
(
gb
.
exclude_seed_edges
))
assert
len
(
list
(
datapipe
))
==
5
assert
len
(
list
(
datapipe
))
==
5
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
def
test_SubgraphSampler_Link_Hetero_With_Negative
(
labor
):
"sampler_type"
,
[
SamplerType
.
Normal
,
SamplerType
.
Layer
,
SamplerType
.
Temporal
],
)
def
test_SubgraphSampler_Link_Hetero_With_Negative
(
sampler_type
):
graph
=
get_hetero_graph
().
to
(
F
.
ctx
())
graph
=
get_hetero_graph
().
to
(
F
.
ctx
())
first_items
=
torch
.
LongTensor
([[
0
,
0
,
1
,
1
],
[
0
,
2
,
0
,
1
]]).
T
first_names
=
"node_pairs"
second_items
=
torch
.
LongTensor
([[
0
,
0
,
1
,
1
,
2
,
2
],
[
0
,
1
,
1
,
0
,
0
,
1
]]).
T
second_names
=
"node_pairs"
if
sampler_type
==
SamplerType
.
Temporal
:
graph
.
node_attributes
=
{
"timestamp"
:
torch
.
arange
(
graph
.
csc_indptr
.
numel
()
-
1
).
to
(
F
.
ctx
())
}
graph
.
edge_attributes
=
{
"timestamp"
:
torch
.
arange
(
graph
.
indices
.
numel
()).
to
(
F
.
ctx
())
}
first_items
=
(
first_items
,
torch
.
randint
(
0
,
10
,
(
4
,)))
first_names
=
(
first_names
,
"timestamp"
)
second_items
=
(
second_items
,
torch
.
randint
(
0
,
10
,
(
6
,)))
second_names
=
(
second_names
,
"timestamp"
)
itemset
=
gb
.
ItemSetDict
(
itemset
=
gb
.
ItemSetDict
(
{
{
"n1:e1:n2"
:
gb
.
ItemSet
(
"n1:e1:n2"
:
gb
.
ItemSet
(
torch
.
LongTensor
([[
0
,
0
,
1
,
1
],
[
0
,
2
,
0
,
1
]]).
T
,
first_items
,
names
=
"node_pairs"
,
names
=
first_names
,
),
),
"n2:e2:n1"
:
gb
.
ItemSet
(
"n2:e2:n1"
:
gb
.
ItemSet
(
torch
.
LongTensor
([[
0
,
0
,
1
,
1
,
2
,
2
],
[
0
,
1
,
1
,
0
,
0
,
1
]]).
T
,
second_items
,
names
=
"node_pairs"
,
names
=
second_names
,
),
),
}
}
)
)
...
@@ -209,8 +320,8 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(labor):
...
@@ -209,8 +320,8 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(labor):
num_layer
=
2
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
datapipe
=
gb
.
UniformNegativeSampler
(
datapipe
,
graph
,
1
)
datapipe
=
gb
.
UniformNegativeSampler
(
datapipe
,
graph
,
1
)
S
ampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
s
ampler
=
_get_sampler
(
sampler_type
)
datapipe
=
S
ampler
(
datapipe
,
graph
,
fanouts
)
datapipe
=
s
ampler
(
datapipe
,
graph
,
fanouts
)
datapipe
=
datapipe
.
transform
(
partial
(
gb
.
exclude_seed_edges
))
datapipe
=
datapipe
.
transform
(
partial
(
gb
.
exclude_seed_edges
))
assert
len
(
list
(
datapipe
))
==
5
assert
len
(
list
(
datapipe
))
==
5
...
@@ -219,8 +330,11 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(labor):
...
@@ -219,8 +330,11 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(labor):
F
.
_default_context_str
!=
"cpu"
,
F
.
_default_context_str
!=
"cpu"
,
reason
=
"Sampling with replacement not yet supported on GPU."
,
reason
=
"Sampling with replacement not yet supported on GPU."
,
)
)
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
def
test_SubgraphSampler_Random_Hetero_Graph
(
labor
):
"sampler_type"
,
[
SamplerType
.
Normal
,
SamplerType
.
Layer
,
SamplerType
.
Temporal
],
)
def
test_SubgraphSampler_Random_Hetero_Graph
(
sampler_type
):
num_nodes
=
5
num_nodes
=
5
num_edges
=
9
num_edges
=
9
num_ntypes
=
3
num_ntypes
=
3
...
@@ -235,10 +349,14 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
...
@@ -235,10 +349,14 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
)
=
gb_test_utils
.
random_hetero_graph
(
)
=
gb_test_utils
.
random_hetero_graph
(
num_nodes
,
num_edges
,
num_ntypes
,
num_etypes
num_nodes
,
num_edges
,
num_ntypes
,
num_etypes
)
)
node_attributes
=
{}
edge_attributes
=
{
edge_attributes
=
{
"A1"
:
torch
.
randn
(
num_edges
),
"A1"
:
torch
.
randn
(
num_edges
),
"A2"
:
torch
.
randn
(
num_edges
),
"A2"
:
torch
.
randn
(
num_edges
),
}
}
if
sampler_type
==
SamplerType
.
Temporal
:
node_attributes
[
"timestamp"
]
=
torch
.
randint
(
0
,
10
,
(
num_nodes
,))
edge_attributes
[
"timestamp"
]
=
torch
.
randint
(
0
,
10
,
(
num_edges
,))
graph
=
gb
.
fused_csc_sampling_graph
(
graph
=
gb
.
fused_csc_sampling_graph
(
csc_indptr
,
csc_indptr
,
indices
,
indices
,
...
@@ -246,21 +364,31 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
...
@@ -246,21 +364,31 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
type_per_edge
=
type_per_edge
,
type_per_edge
=
type_per_edge
,
node_type_to_id
=
node_type_to_id
,
node_type_to_id
=
node_type_to_id
,
edge_type_to_id
=
edge_type_to_id
,
edge_type_to_id
=
edge_type_to_id
,
node_attributes
=
node_attributes
,
edge_attributes
=
edge_attributes
,
edge_attributes
=
edge_attributes
,
).
to
(
F
.
ctx
())
).
to
(
F
.
ctx
())
first_items
=
torch
.
tensor
([
0
])
first_names
=
"seed_nodes"
second_items
=
torch
.
tensor
([
0
])
second_names
=
"seed_nodes"
if
sampler_type
==
SamplerType
.
Temporal
:
first_items
=
(
first_items
,
torch
.
randint
(
0
,
10
,
(
1
,)))
first_names
=
(
first_names
,
"timestamp"
)
second_items
=
(
second_items
,
torch
.
randint
(
0
,
10
,
(
1
,)))
second_names
=
(
second_names
,
"timestamp"
)
itemset
=
gb
.
ItemSetDict
(
itemset
=
gb
.
ItemSetDict
(
{
{
"n2"
:
gb
.
ItemSet
(
torch
.
tensor
([
0
]),
names
=
"seed_nod
es
"
),
"n2"
:
gb
.
ItemSet
(
first_items
,
names
=
first_nam
es
),
"n1"
:
gb
.
ItemSet
(
torch
.
tensor
([
0
])
,
names
=
"
se
e
d_n
od
es
"
),
"n1"
:
gb
.
ItemSet
(
second_items
,
names
=
se
con
d_n
am
es
),
}
}
)
)
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
).
copy_to
(
F
.
ctx
())
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
).
copy_to
(
F
.
ctx
())
num_layer
=
2
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
S
ampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
s
ampler
=
_get_sampler
(
sampler_type
)
sampler_dp
=
S
ampler
(
item_sampler
,
graph
,
fanouts
,
replace
=
True
)
sampler_dp
=
s
ampler
(
item_sampler
,
graph
,
fanouts
,
replace
=
True
)
for
data
in
sampler_dp
:
for
data
in
sampler_dp
:
for
sampledsubgraph
in
data
.
sampled_subgraphs
:
for
sampledsubgraph
in
data
.
sampled_subgraphs
:
...
@@ -289,23 +417,40 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
...
@@ -289,23 +417,40 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
F
.
_default_context_str
!=
"cpu"
,
F
.
_default_context_str
!=
"cpu"
,
reason
=
"Fails due to randomness on the GPU."
,
reason
=
"Fails due to randomness on the GPU."
,
)
)
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
def
test_SubgraphSampler_without_dedpulication_Homo
(
labor
):
"sampler_type"
,
[
SamplerType
.
Normal
,
SamplerType
.
Layer
,
SamplerType
.
Temporal
],
)
def
test_SubgraphSampler_without_dedpulication_Homo
(
sampler_type
):
graph
=
dgl
.
graph
(
graph
=
dgl
.
graph
(
([
5
,
0
,
1
,
5
,
6
,
7
,
2
,
2
,
4
],
[
0
,
1
,
2
,
2
,
2
,
2
,
3
,
4
,
4
])
([
5
,
0
,
1
,
5
,
6
,
7
,
2
,
2
,
4
],
[
0
,
1
,
2
,
2
,
2
,
2
,
3
,
4
,
4
])
)
)
graph
=
gb
.
from_dglgraph
(
graph
,
True
).
to
(
F
.
ctx
())
graph
=
gb
.
from_dglgraph
(
graph
,
True
).
to
(
F
.
ctx
())
seed_nodes
=
torch
.
LongTensor
([
0
,
3
,
4
])
seed_nodes
=
torch
.
LongTensor
([
0
,
3
,
4
])
items
=
seed_nodes
names
=
"seed_nodes"
if
sampler_type
==
SamplerType
.
Temporal
:
graph
.
node_attributes
=
{
"timestamp"
:
torch
.
zeros
(
graph
.
csc_indptr
.
numel
()
-
1
).
to
(
F
.
ctx
())
}
graph
.
edge_attributes
=
{
"timestamp"
:
torch
.
zeros
(
graph
.
indices
.
numel
()).
to
(
F
.
ctx
())
}
items
=
(
items
,
torch
.
randint
(
0
,
10
,
(
3
,)))
names
=
(
names
,
"timestamp"
)
itemset
=
gb
.
ItemSet
(
seed_nodes
,
names
=
"seed_nod
es
"
)
itemset
=
gb
.
ItemSet
(
items
,
names
=
nam
es
)
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
len
(
seed_nodes
)).
copy_to
(
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
len
(
seed_nodes
)).
copy_to
(
F
.
ctx
()
F
.
ctx
()
)
)
num_layer
=
2
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
Sampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
sampler
=
_get_sampler
(
sampler_type
)
datapipe
=
Sampler
(
item_sampler
,
graph
,
fanouts
,
deduplicate
=
False
)
if
sampler_type
==
SamplerType
.
Temporal
:
datapipe
=
sampler
(
item_sampler
,
graph
,
fanouts
)
else
:
datapipe
=
sampler
(
item_sampler
,
graph
,
fanouts
,
deduplicate
=
False
)
length
=
[
17
,
7
]
length
=
[
17
,
7
]
compacted_indices
=
[
compacted_indices
=
[
...
@@ -334,17 +479,32 @@ def test_SubgraphSampler_without_dedpulication_Homo(labor):
...
@@ -334,17 +479,32 @@ def test_SubgraphSampler_without_dedpulication_Homo(labor):
)
)
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
def
test_SubgraphSampler_without_dedpulication_Hetero
(
labor
):
"sampler_type"
,
[
SamplerType
.
Normal
,
SamplerType
.
Layer
,
SamplerType
.
Temporal
],
)
def
test_SubgraphSampler_without_dedpulication_Hetero
(
sampler_type
):
graph
=
get_hetero_graph
().
to
(
F
.
ctx
())
graph
=
get_hetero_graph
().
to
(
F
.
ctx
())
itemset
=
gb
.
ItemSetDict
(
items
=
torch
.
arange
(
2
)
{
"n2"
:
gb
.
ItemSet
(
torch
.
arange
(
2
),
names
=
"seed_nodes"
)}
names
=
"seed_nodes"
)
if
sampler_type
==
SamplerType
.
Temporal
:
graph
.
node_attributes
=
{
"timestamp"
:
torch
.
zeros
(
graph
.
csc_indptr
.
numel
()
-
1
).
to
(
F
.
ctx
())
}
graph
.
edge_attributes
=
{
"timestamp"
:
torch
.
zeros
(
graph
.
indices
.
numel
()).
to
(
F
.
ctx
())
}
items
=
(
items
,
torch
.
randint
(
0
,
10
,
(
2
,)))
names
=
(
names
,
"timestamp"
)
itemset
=
gb
.
ItemSetDict
({
"n2"
:
gb
.
ItemSet
(
items
,
names
=
names
)})
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
).
copy_to
(
F
.
ctx
())
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
).
copy_to
(
F
.
ctx
())
num_layer
=
2
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
Sampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
sampler
=
_get_sampler
(
sampler_type
)
datapipe
=
Sampler
(
item_sampler
,
graph
,
fanouts
,
deduplicate
=
False
)
if
sampler_type
==
SamplerType
.
Temporal
:
datapipe
=
sampler
(
item_sampler
,
graph
,
fanouts
)
else
:
datapipe
=
sampler
(
item_sampler
,
graph
,
fanouts
,
deduplicate
=
False
)
csc_formats
=
[
csc_formats
=
[
{
{
"n1:e1:n2"
:
gb
.
CSCFormatBase
(
"n1:e1:n2"
:
gb
.
CSCFormatBase
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment