Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
4135b1bd
Unverified
Commit
4135b1bd
authored
Jul 20, 2023
by
AdamGrabowski
Committed by
GitHub
Jul 20, 2023
Browse files
[Performance] Fused sampling with compaction (#5924)
Co-authored-by:
Hesham Mostafa
<
hesham.mostafa@intel.com
>
parent
4ceb0bff
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
1280 additions
and
82 deletions
+1280
-82
benchmarks/benchmarks/api/bench_fused_sample_neighbors.py
benchmarks/benchmarks/api/bench_fused_sample_neighbors.py
+39
-0
include/dgl/aten/csr.h
include/dgl/aten/csr.h
+66
-0
include/dgl/sampling/neighbor.h
include/dgl/sampling/neighbor.h
+50
-0
python/dgl/dataloading/neighbor_sampler.py
python/dgl/dataloading/neighbor_sampler.py
+39
-0
python/dgl/sampling/neighbor.py
python/dgl/sampling/neighbor.py
+210
-16
src/array/array.cc
src/array/array.cc
+41
-0
src/array/array_op.h
src/array/array_op.h
+13
-0
src/array/cpu/concurrent_id_hash_map.cc
src/array/cpu/concurrent_id_hash_map.cc
+22
-0
src/array/cpu/concurrent_id_hash_map.h
src/array/cpu/concurrent_id_hash_map.h
+3
-0
src/array/cpu/rowwise_pick.h
src/array/cpu/rowwise_pick.h
+110
-0
src/array/cpu/rowwise_sampling.cc
src/array/cpu/rowwise_sampling.cc
+95
-0
src/graph/sampling/neighbor/neighbor.cc
src/graph/sampling/neighbor/neighbor.cc
+350
-0
src/graph/unit_graph.cc
src/graph/unit_graph.cc
+30
-0
src/graph/unit_graph.h
src/graph/unit_graph.h
+12
-0
tests/python/common/sampling/test_sampling.py
tests/python/common/sampling/test_sampling.py
+200
-66
No files found.
benchmarks/benchmarks/api/bench_fused_sample_neighbors.py
0 → 100644
View file @
4135b1bd
import
time
import
dgl
import
dgl.function
as
fn
import
numpy
as
np
import
torch
from
..
import
utils
@
utils
.
benchmark
(
"time"
)
@
utils
.
parametrize_cpu
(
"graph_name"
,
[
"livejournal"
,
"reddit"
])
@
utils
.
parametrize_gpu
(
"graph_name"
,
[
"ogbn-arxiv"
,
"reddit"
])
@
utils
.
parametrize
(
"format"
,
[
"csr"
,
"csc"
])
@
utils
.
parametrize
(
"seed_nodes_num"
,
[
200
,
5000
,
20000
])
@
utils
.
parametrize
(
"fanout"
,
[
5
,
20
,
40
])
def
track_time
(
graph_name
,
format
,
seed_nodes_num
,
fanout
):
device
=
utils
.
get_bench_device
()
graph
=
utils
.
get_graph
(
graph_name
,
format
).
to
(
device
)
edge_dir
=
"in"
if
format
==
"csc"
else
"out"
seed_nodes
=
np
.
random
.
randint
(
0
,
graph
.
num_nodes
(),
seed_nodes_num
)
seed_nodes
=
torch
.
from_numpy
(
seed_nodes
).
to
(
device
)
# dry run
for
i
in
range
(
3
):
dgl
.
sampling
.
sample_neighbors_fused
(
graph
,
seed_nodes
,
fanout
,
edge_dir
=
edge_dir
)
# timing
with
utils
.
Timer
()
as
t
:
for
i
in
range
(
50
):
dgl
.
sampling
.
sample_neighbors_fused
(
graph
,
seed_nodes
,
fanout
,
edge_dir
=
edge_dir
)
return
t
.
elapsed_secs
/
50
include/dgl/aten/csr.h
View file @
4135b1bd
...
...
@@ -572,6 +572,72 @@ COOMatrix CSRRowWiseSampling(
CSRMatrix
mat
,
IdArray
rows
,
int64_t
num_samples
,
NDArray
prob_or_mask
=
NDArray
(),
bool
replace
=
true
);
/*!
* @brief Randomly select a fixed number of non-zero entries along each given
* row independently.
*
* The function performs random choices along each row independently.
* The picked indices are returned in the form of a CSR matrix, with
* additional IdArray that is an extended version of CSR's index pointers.
*
* With template parameter set to True rows are also saved as new seed nodes and
* mapped
*
* If replace is false and a row has fewer non-zero values than num_samples,
* all the values are picked.
*
* Examples:
*
* // csr.num_rows = 4;
* // csr.num_cols = 4;
* // csr.indptr = [0, 2, 3, 3, 5]
* // csr.indices = [0, 1, 1, 2, 3]
* // csr.data = [2, 3, 0, 1, 4]
* CSRMatrix csr = ...;
* IdArray rows = ... ; // [1, 3]
* IdArray seed_mapping = [-1, -1, -1, -1];
* std::vector<IdType> new_seed_nodes = {};
*
* std::pair<CSRMatrix, IdArray> sampled = CSRRowWiseSamplingFused<
* typename IdType, True>(
* csr, rows, seed_mapping,
* new_seed_nodes, 2,
* FloatArray(), false);
* // possible sampled csr matrix:
* // sampled.first.num_rows = 2
* // sampled.first.num_cols = 3
* // sampled.first.indptr = [0, 1, 3]
* // sampled.first.indices = [1, 2, 3]
* // sampled.first.data = [0, 1, 4]
* // sampled.second = [0, 1, 1]
* // seed_mapping = [-1, 0, -1, 1];
* // new_seed_nodes = {1, 3};
*
* @tparam IdType Graph's index data type, can be int32_t or int64_t
* @tparam map_seed_nodes If set for true we map and copy rows to new_seed_nodes
* @param mat Input CSR matrix.
* @param rows Rows to sample from.
* @param seed_mapping Mapping array used if map_seed_nodes=true. If so each row
* from rows will be set to its position e.g. mapping[rows[i]] = i.
* @param new_seed_nodes Vector used if map_seed_nodes=true. If so it will
* contain rows.
* @param rows Rows to sample from.
* @param num_samples Number of samples
* @param prob_or_mask Unnormalized probability array or mask array.
* Should be of the same length as the data array.
* If an empty array is provided, assume uniform.
* @param replace True if sample with replacement
* @return A CSRMatrix storing the picked row, col and data indices,
* COO version of picked rows
* @note The edges of the entire graph must be ordered by their edge types,
* rows must be unique
*/
template
<
typename
IdType
,
bool
map_seed_nodes
>
std
::
pair
<
CSRMatrix
,
IdArray
>
CSRRowWiseSamplingFused
(
CSRMatrix
mat
,
IdArray
rows
,
IdArray
seed_mapping
,
std
::
vector
<
IdType
>*
new_seed_nodes
,
int64_t
num_samples
,
NDArray
prob_or_mask
=
NDArray
(),
bool
replace
=
true
);
/**
* @brief Randomly select a fixed number of non-zero entries for each edge type
* along each given row independently.
...
...
include/dgl/sampling/neighbor.h
View file @
4135b1bd
...
...
@@ -9,6 +9,7 @@
#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <tuple>
#include <vector>
namespace
dgl
{
...
...
@@ -47,6 +48,55 @@ HeteroSubgraph SampleNeighbors(
const
std
::
vector
<
FloatArray
>&
probability
,
const
std
::
vector
<
IdArray
>&
exclude_edges
,
bool
replace
=
true
);
/**
* @brief Sample from the neighbors of the given nodes and convert a graph into
* a bipartite-structured graph for message passing.
*
* Specifically, we create one node type \c ntype_l on the "left" side and
* another node type \c ntype_r on the "right" side for each node type \c ntype.
* The nodes of type \c ntype_r would contain the nodes designated by the
* caller, and node type \c ntype_l would contain the nodes that has an edge
* connecting to one of the designated nodes.
*
* The nodes of \c ntype_l would also contain the nodes in node type \c ntype_r.
* When sampling with replacement, the sampled subgraph could have parallel
* edges.
*
* For sampling without replace, if fanout > the number of neighbors, all the
* neighbors will be sampled.
*
* Non-deterministic algorithm, requires nodes parameter to store unique Node
* IDs.
*
* @tparam IdType Graph's index data type, can be int32_t or int64_t
* @param hg The input graph.
* @param nodes Node IDs of each type. The vector length must be equal to the
* number of node types. Empty array is allowed.
* @param mapping External parameter that should be set to a vector of IdArrays
* filled with -1, required for mapping of nodes in returned
* graph
* @param fanouts Number of sampled neighbors for each edge type. The vector
* length should be equal to the number of edge types, or one if they all have
* the same fanout.
* @param dir Edge direction.
* @param probability A vector of 1D float arrays, indicating the transition
* probability of each edge by edge type. An empty float array assumes uniform
* transition.
* @param exclude_edges Edges IDs of each type which will be excluded during
* sampling. The vector length must be equal to the number of edges types. Empty
* array is allowed.
* @param replace If true, sample with replacement.
* @return Sampled neighborhoods as a graph. The return graph has the same
* schema as the original one.
*/
template
<
typename
IdType
>
std
::
tuple
<
HeteroGraphPtr
,
std
::
vector
<
IdArray
>
,
std
::
vector
<
IdArray
>>
SampleNeighborsFused
(
const
HeteroGraphPtr
hg
,
const
std
::
vector
<
IdArray
>&
nodes
,
const
std
::
vector
<
IdArray
>&
mapping
,
const
std
::
vector
<
int64_t
>&
fanouts
,
EdgeDir
dir
,
const
std
::
vector
<
NDArray
>&
prob_or_mask
,
const
std
::
vector
<
IdArray
>&
exclude_edges
,
bool
replace
=
true
);
/**
* Select the neighbors with k-largest weights on the connecting edges for each
* given node.
...
...
python/dgl/dataloading/neighbor_sampler.py
View file @
4135b1bd
"""Data loading components for neighbor sampling"""
from
..
import
backend
as
F
from
..base
import
EID
,
NID
from
..heterograph
import
DGLGraph
from
..transforms
import
to_block
from
.base
import
BlockSampler
...
...
@@ -54,6 +56,9 @@ class NeighborSampler(BlockSampler):
output_device : device, optional
The device of the output subgraphs or MFGs. Default is the same as the
minibatch of seed nodes.
fused : bool, default True
If True and device is CPU fused sample neighbors is invoked. This version
requires seed_nodes to be unique
Examples
--------
...
...
@@ -120,6 +125,7 @@ class NeighborSampler(BlockSampler):
prefetch_labels
=
None
,
prefetch_edge_feats
=
None
,
output_device
=
None
,
fused
=
True
,
):
super
().
__init__
(
prefetch_node_feats
=
prefetch_node_feats
,
...
...
@@ -137,10 +143,43 @@ class NeighborSampler(BlockSampler):
)
self
.
prob
=
prob
or
mask
self
.
replace
=
replace
self
.
fused
=
fused
self
.
mapping
=
{}
self
.
g
=
None
def
sample_blocks
(
self
,
g
,
seed_nodes
,
exclude_eids
=
None
):
output_nodes
=
seed_nodes
blocks
=
[]
if
self
.
fused
:
cpu
=
F
.
device_type
(
g
.
device
)
==
"cpu"
if
isinstance
(
seed_nodes
,
dict
):
for
ntype
in
list
(
seed_nodes
.
keys
()):
if
not
cpu
:
break
cpu
=
(
cpu
and
F
.
device_type
(
seed_nodes
[
ntype
].
device
)
==
"cpu"
)
else
:
cpu
=
cpu
and
F
.
device_type
(
seed_nodes
.
device
)
==
"cpu"
if
cpu
and
isinstance
(
g
,
DGLGraph
)
and
F
.
backend_name
==
"pytorch"
:
if
self
.
g
!=
g
:
self
.
mapping
=
{}
self
.
g
=
g
for
fanout
in
reversed
(
self
.
fanouts
):
block
=
g
.
sample_neighbors_fused
(
seed_nodes
,
fanout
,
edge_dir
=
self
.
edge_dir
,
prob
=
self
.
prob
,
replace
=
self
.
replace
,
exclude_edges
=
exclude_eids
,
mapping
=
self
.
mapping
,
)
seed_nodes
=
block
.
srcdata
[
NID
]
blocks
.
insert
(
0
,
block
)
return
seed_nodes
,
output_nodes
,
blocks
for
fanout
in
reversed
(
self
.
fanouts
):
frontier
=
g
.
sample_neighbors
(
seed_nodes
,
...
...
python/dgl/sampling/neighbor.py
View file @
4135b1bd
"""Neighbor sampling APIs"""
import
os
import
torch
from
..
import
backend
as
F
,
ndarray
as
nd
,
utils
from
.._ffi.function
import
_init_api
from
..base
import
DGLError
,
EID
from
..heterograph
import
DGLGraph
from
..heterograph
import
DGLBlock
,
DGLGraph
from
.utils
import
EidExcluder
__all__
=
[
"sample_etype_neighbors"
,
"sample_neighbors"
,
"sample_neighbors_fused"
,
"sample_neighbors_biased"
,
"select_topk"
,
]
...
...
@@ -379,6 +384,126 @@ def sample_neighbors(
return
frontier
if
output_device
is
None
else
frontier
.
to
(
output_device
)
def
sample_neighbors_fused
(
g
,
nodes
,
fanout
,
edge_dir
=
"in"
,
prob
=
None
,
replace
=
False
,
copy_ndata
=
True
,
copy_edata
=
True
,
exclude_edges
=
None
,
mapping
=
None
,
):
"""Sample neighboring edges of the given nodes and return the induced subgraph.
For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges
will be randomly chosen. The graph returned will then contain all the nodes in the
original graph, but only the sampled edges. Nodes will be renumbered starting from id 0,
which would be new node id of first seed node.
Parameters
----------
g : DGLGraph
The graph. Can be either on CPU or GPU.
nodes : tensor or dict
Node IDs to sample neighbors from.
This argument can take a single ID tensor or a dictionary of node types and ID tensors.
If a single tensor is given, the graph must only have one type of nodes.
fanout : int or dict[etype, int]
The number of edges to be sampled for each node on each edge type.
This argument can take a single int or a dictionary of edge types and ints.
If a single int is given, DGL will sample this number of edges for each node for
every edge type.
If -1 is given for a single edge type, all the neighboring edges with that edge
type and non-zero probability will be selected.
edge_dir : str, optional
Determines whether to sample inbound or outbound edges.
Can take either ``in`` for inbound edges or ``out`` for outbound edges.
prob : str, optional
Feature name used as the (unnormalized) probabilities associated with each
neighboring edge of a node. The feature must have only one element for each
edge.
The features must be non-negative floats or boolean. Otherwise, the result
will be undefined.
exclude_edges: tensor or dict
Edge IDs to exclude during sampling neighbors for the seed nodes.
This argument can take a single ID tensor or a dictionary of edge types and ID tensors.
If a single tensor is given, the graph must only have one type of nodes.
replace : bool, optional
If True, sample with replacement.
copy_ndata: bool, optional
If True, the node features of the new graph are copied from
the original graph. If False, the new graph will not have any
node features.
(Default: True)
copy_edata: bool, optional
If True, the edge features of the new graph are copied from
the original graph. If False, the new graph will not have any
edge features.
(Default: False)
mapping : dictionary, optional
Used by fused version of NeighborSampler. To avoid constant data allocation
provide empty dictionary ({}) that will be allocated once with proper data and reused
by each function call
(Default: None)
Returns
-------
DGLGraph
A sampled subgraph containing only the sampled neighboring edges.
Notes
-----
If :attr:`copy_ndata` or :attr:`copy_edata` is True, same tensors are used as
the node or edge features of the original graph and the new graph.
As a result, users should avoid performing in-place operations
on the node features of the new graph to avoid feature corruption.
"""
if
not
g
.
is_pinned
():
frontier
=
_sample_neighbors
(
g
,
nodes
,
fanout
,
edge_dir
=
edge_dir
,
prob
=
prob
,
replace
=
replace
,
copy_ndata
=
copy_ndata
,
copy_edata
=
copy_edata
,
exclude_edges
=
exclude_edges
,
fused
=
True
,
mapping
=
mapping
,
)
else
:
frontier
=
_sample_neighbors
(
g
,
nodes
,
fanout
,
edge_dir
=
edge_dir
,
prob
=
prob
,
replace
=
replace
,
copy_ndata
=
copy_ndata
,
copy_edata
=
copy_edata
,
fused
=
True
,
mapping
=
mapping
,
)
if
exclude_edges
is
not
None
:
eid_excluder
=
EidExcluder
(
exclude_edges
)
frontier
=
eid_excluder
(
frontier
)
return
frontier
def
_sample_neighbors
(
g
,
nodes
,
...
...
@@ -390,6 +515,8 @@ def _sample_neighbors(
copy_edata
=
True
,
_dist_training
=
False
,
exclude_edges
=
None
,
fused
=
False
,
mapping
=
None
,
):
if
not
isinstance
(
nodes
,
dict
):
if
len
(
g
.
ntypes
)
>
1
:
...
...
@@ -446,6 +573,53 @@ def _sample_neighbors(
else
:
excluded_edges_all_t
.
append
(
nd
.
array
([],
ctx
=
ctx
))
if
fused
:
if
_dist_training
:
raise
DGLError
(
"distributed training not supported in fused sampling"
)
cpu
=
F
.
device_type
(
g
.
device
)
==
"cpu"
if
isinstance
(
nodes
,
dict
):
for
ntype
in
list
(
nodes
.
keys
()):
if
not
cpu
:
break
cpu
=
cpu
and
F
.
device_type
(
nodes
[
ntype
].
device
)
==
"cpu"
else
:
cpu
=
cpu
and
F
.
device_type
(
nodes
.
device
)
==
"cpu"
if
not
cpu
or
F
.
backend_name
!=
"pytorch"
:
raise
DGLError
(
"Only PyTorch backend and cpu is supported in fused sampling"
)
if
mapping
is
None
:
mapping
=
{}
mapping_name
=
"__mapping"
+
str
(
os
.
getpid
())
if
mapping_name
not
in
mapping
.
keys
():
mapping
[
mapping_name
]
=
[
torch
.
LongTensor
(
g
.
num_nodes
(
ntype
)).
fill_
(
-
1
)
for
ntype
in
g
.
ntypes
]
subgidx
,
induced_nodes
,
induced_edges
=
_CAPI_DGLSampleNeighborsFused
(
g
.
_graph
,
nodes_all_types
,
[
F
.
to_dgl_nd
(
m
)
for
m
in
mapping
[
mapping_name
]],
fanout_array
,
edge_dir
,
prob_arrays
,
excluded_edges_all_t
,
replace
,
)
for
mapping_vector
,
src_nodes
in
zip
(
mapping
[
mapping_name
],
induced_nodes
):
mapping_vector
[
F
.
from_dgl_nd
(
src_nodes
).
type
(
F
.
int64
)]
=
-
1
new_ntypes
=
(
g
.
ntypes
,
g
.
ntypes
)
ret
=
DGLBlock
(
subgidx
,
new_ntypes
,
g
.
etypes
)
assert
ret
.
is_unibipartite
else
:
subgidx
=
_CAPI_DGLSampleNeighbors
(
g
.
_graph
,
nodes_all_types
,
...
...
@@ -455,8 +629,8 @@ def _sample_neighbors(
excluded_edges_all_t
,
replace
,
)
induced_edges
=
subgidx
.
induced_edges
ret
=
DGLGraph
(
subgidx
.
graph
,
g
.
ntypes
,
g
.
etypes
)
induced_edges
=
subgidx
.
induced_edges
# handle features
# (TODO) (BarclayII) DGL distributed fails with bus error, freezes, or other
...
...
@@ -465,12 +639,31 @@ def _sample_neighbors(
# only set the edge IDs.
if
not
_dist_training
:
if
copy_ndata
:
if
fused
:
src_node_ids
=
[
F
.
from_dgl_nd
(
src
)
for
src
in
induced_nodes
]
dst_node_ids
=
[
utils
.
toindex
(
nodes
.
get
(
ntype
,
[]),
g
.
_idtype_str
).
tousertensor
(
ctx
=
F
.
to_backend_ctx
(
g
.
_graph
.
ctx
))
for
ntype
in
g
.
ntypes
]
node_frames
=
utils
.
extract_node_subframes_for_block
(
g
,
src_node_ids
,
dst_node_ids
)
utils
.
set_new_frames
(
ret
,
node_frames
=
node_frames
)
else
:
node_frames
=
utils
.
extract_node_subframes
(
g
,
device
)
utils
.
set_new_frames
(
ret
,
node_frames
=
node_frames
)
if
copy_edata
:
if
fused
:
edge_ids
=
[
F
.
from_dgl_nd
(
eid
)
for
eid
in
induced_edges
]
edge_frames
=
utils
.
extract_edge_subframes
(
g
,
edge_ids
)
utils
.
set_new_frames
(
ret
,
edge_frames
=
edge_frames
)
else
:
edge_frames
=
utils
.
extract_edge_subframes
(
g
,
induced_edges
)
utils
.
set_new_frames
(
ret
,
edge_frames
=
edge_frames
)
else
:
for
i
,
etype
in
enumerate
(
ret
.
canonical_etypes
):
ret
.
edges
[
etype
].
data
[
EID
]
=
induced_edges
[
i
]
...
...
@@ -479,6 +672,7 @@ def _sample_neighbors(
DGLGraph
.
sample_neighbors
=
utils
.
alias_func
(
sample_neighbors
)
DGLGraph
.
sample_neighbors_fused
=
utils
.
alias_func
(
sample_neighbors_fused
)
def
sample_neighbors_biased
(
...
...
src/array/array.cc
View file @
4135b1bd
...
...
@@ -597,6 +597,47 @@ COOMatrix CSRRowWiseSampling(
return
ret
;
}
template
<
typename
IdType
,
bool
map_seed_nodes
>
std
::
pair
<
CSRMatrix
,
IdArray
>
CSRRowWiseSamplingFused
(
CSRMatrix
mat
,
IdArray
rows
,
IdArray
seed_mapping
,
std
::
vector
<
IdType
>*
new_seed_nodes
,
int64_t
num_samples
,
NDArray
prob_or_mask
,
bool
replace
)
{
std
::
pair
<
CSRMatrix
,
IdArray
>
ret
;
if
(
IsNullArray
(
prob_or_mask
))
{
ATEN_XPU_SWITCH
(
rows
->
ctx
.
device_type
,
XPU
,
"CSRRowWiseSamplingUniformFused"
,
{
ret
=
impl
::
CSRRowWiseSamplingUniformFused
<
XPU
,
IdType
,
map_seed_nodes
>
(
mat
,
rows
,
seed_mapping
,
new_seed_nodes
,
num_samples
,
replace
);
});
}
else
{
CHECK_VALID_CONTEXT
(
prob_or_mask
,
rows
);
ATEN_XPU_SWITCH
(
rows
->
ctx
.
device_type
,
XPU
,
"CSRRowWiseSamplingFused"
,
{
ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH
(
prob_or_mask
->
dtype
,
FloatType
,
"probability or mask"
,
{
ret
=
impl
::
CSRRowWiseSamplingFused
<
XPU
,
IdType
,
FloatType
,
map_seed_nodes
>
(
mat
,
rows
,
seed_mapping
,
new_seed_nodes
,
num_samples
,
prob_or_mask
,
replace
);
});
});
}
return
ret
;
}
template
std
::
pair
<
CSRMatrix
,
IdArray
>
CSRRowWiseSamplingFused
<
int64_t
,
true
>
(
CSRMatrix
,
IdArray
,
IdArray
,
std
::
vector
<
int64_t
>*
,
int64_t
,
NDArray
,
bool
);
template
std
::
pair
<
CSRMatrix
,
IdArray
>
CSRRowWiseSamplingFused
<
int64_t
,
false
>
(
CSRMatrix
,
IdArray
,
IdArray
,
std
::
vector
<
int64_t
>*
,
int64_t
,
NDArray
,
bool
);
template
std
::
pair
<
CSRMatrix
,
IdArray
>
CSRRowWiseSamplingFused
<
int32_t
,
true
>
(
CSRMatrix
,
IdArray
,
IdArray
,
std
::
vector
<
int32_t
>*
,
int64_t
,
NDArray
,
bool
);
template
std
::
pair
<
CSRMatrix
,
IdArray
>
CSRRowWiseSamplingFused
<
int32_t
,
false
>
(
CSRMatrix
,
IdArray
,
IdArray
,
std
::
vector
<
int32_t
>*
,
int64_t
,
NDArray
,
bool
);
COOMatrix
CSRRowWisePerEtypeSampling
(
CSRMatrix
mat
,
IdArray
rows
,
const
std
::
vector
<
int64_t
>&
eid2etype_offset
,
const
std
::
vector
<
int64_t
>&
num_samples
,
...
...
src/array/array_op.h
View file @
4135b1bd
...
...
@@ -178,6 +178,14 @@ COOMatrix CSRRowWiseSampling(
CSRMatrix
mat
,
IdArray
rows
,
int64_t
num_samples
,
NDArray
prob_or_mask
,
bool
replace
);
// FloatType is the type of probability data.
template
<
DGLDeviceType
XPU
,
typename
IdxType
,
typename
DType
,
bool
map_seed_nodes
>
std
::
pair
<
CSRMatrix
,
IdArray
>
CSRRowWiseSamplingFused
(
CSRMatrix
mat
,
IdArray
rows
,
IdArray
seed_mapping
,
std
::
vector
<
IdxType
>*
new_seed_nodes
,
int64_t
num_samples
,
NDArray
prob_or_mask
,
bool
replace
);
// FloatType is the type of probability data.
template
<
DGLDeviceType
XPU
,
typename
IdType
,
typename
DType
>
COOMatrix
CSRRowWisePerEtypeSampling
(
...
...
@@ -190,6 +198,11 @@ template <DGLDeviceType XPU, typename IdType>
COOMatrix
CSRRowWiseSamplingUniform
(
CSRMatrix
mat
,
IdArray
rows
,
int64_t
num_samples
,
bool
replace
);
template
<
DGLDeviceType
XPU
,
typename
IdType
,
bool
map_seed_nodes
>
std
::
pair
<
CSRMatrix
,
IdArray
>
CSRRowWiseSamplingUniformFused
(
CSRMatrix
mat
,
IdArray
rows
,
IdArray
seed_mapping
,
std
::
vector
<
IdType
>*
new_seed_nodes
,
int64_t
num_samples
,
bool
replace
);
template
<
DGLDeviceType
XPU
,
typename
IdType
>
COOMatrix
CSRRowWisePerEtypeSamplingUniform
(
CSRMatrix
mat
,
IdArray
rows
,
const
std
::
vector
<
int64_t
>&
eid2etype_offset
,
...
...
src/array/cpu/concurrent_id_hash_map.cc
View file @
4135b1bd
...
...
@@ -223,5 +223,27 @@ ConcurrentIdHashMap<IdType>::AttemptInsertAt(int64_t pos, IdType key) {
template
class
ConcurrentIdHashMap
<
int32_t
>;
template
class
ConcurrentIdHashMap
<
int64_t
>;
template
<
typename
IdType
>
bool
BoolCompareAndSwap
(
IdType
*
ptr
)
{
#ifdef _MSC_VER
if
(
sizeof
(
IdType
)
==
4
)
{
return
_InterlockedCompareExchange
(
reinterpret_cast
<
LONG
*>
(
ptr
),
0
,
-
1
)
==
-
1
;
}
else
if
(
sizeof
(
IdType
)
==
8
)
{
return
_InterlockedCompareExchange64
(
reinterpret_cast
<
LONGLONG
*>
(
ptr
),
0
,
-
1
)
==
-
1
;
}
else
{
LOG
(
FATAL
)
<<
"ID can only be int32 or int64"
;
}
#elif __GNUC__ // _MSC_VER
return
__sync_bool_compare_and_swap
(
ptr
,
-
1
,
0
);
#else // _MSC_VER
#error "CompareAndSwap is not supported on this platform."
#endif // _MSC_VER
}
template
bool
BoolCompareAndSwap
<
int32_t
>(
int32_t
*
);
template
bool
BoolCompareAndSwap
<
int64_t
>(
int64_t
*
);
}
// namespace aten
}
// namespace dgl
src/array/cpu/concurrent_id_hash_map.h
View file @
4135b1bd
...
...
@@ -195,6 +195,9 @@ class ConcurrentIdHashMap {
IdType
mask_
;
};
template
<
typename
IdType
>
bool
BoolCompareAndSwap
(
IdType
*
ptr
);
}
// namespace aten
}
// namespace dgl
...
...
src/array/cpu/rowwise_pick.h
View file @
4135b1bd
...
...
@@ -14,6 +14,7 @@
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
namespace
dgl
{
...
...
@@ -94,6 +95,115 @@ using EtypeRangePickFn = std::function<void(
const
std
::
vector
<
IdxType
>&
et_idx
,
const
std
::
vector
<
IdxType
>&
et_eid
,
const
IdxType
*
eid
,
IdxType
*
out_idx
)
>
;
template
<
typename
IdxType
,
bool
map_seed_nodes
>
std
::
pair
<
CSRMatrix
,
IdArray
>
CSRRowWisePickFused
(
CSRMatrix
mat
,
IdArray
rows
,
IdArray
seed_mapping
,
std
::
vector
<
IdxType
>*
new_seed_nodes
,
int64_t
num_picks
,
bool
replace
,
PickFn
<
IdxType
>
pick_fn
,
NumPicksFn
<
IdxType
>
num_picks_fn
)
{
using
namespace
aten
;
const
IdxType
*
indptr
=
static_cast
<
IdxType
*>
(
mat
.
indptr
->
data
);
const
IdxType
*
indices
=
static_cast
<
IdxType
*>
(
mat
.
indices
->
data
);
const
IdxType
*
data
=
CSRHasData
(
mat
)
?
static_cast
<
IdxType
*>
(
mat
.
data
->
data
)
:
nullptr
;
const
IdxType
*
rows_data
=
static_cast
<
IdxType
*>
(
rows
->
data
);
const
int64_t
num_rows
=
rows
->
shape
[
0
];
const
auto
&
ctx
=
mat
.
indptr
->
ctx
;
const
auto
&
idtype
=
mat
.
indptr
->
dtype
;
IdxType
*
seed_mapping_data
=
nullptr
;
if
(
map_seed_nodes
)
seed_mapping_data
=
seed_mapping
.
Ptr
<
IdxType
>
();
const
int
num_threads
=
runtime
::
compute_num_threads
(
0
,
num_rows
,
1
);
std
::
vector
<
int64_t
>
global_prefix
(
num_threads
+
1
,
0
);
IdArray
picked_col
,
picked_idx
,
picked_coo_rows
;
IdArray
block_csr_indptr
=
IdArray
::
Empty
({
num_rows
+
1
},
idtype
,
ctx
);
IdxType
*
block_csr_indptr_data
=
block_csr_indptr
.
Ptr
<
IdxType
>
();
#pragma omp parallel num_threads(num_threads)
{
const
int
thread_id
=
omp_get_thread_num
();
const
int64_t
start_i
=
thread_id
*
(
num_rows
/
num_threads
)
+
std
::
min
(
static_cast
<
int64_t
>
(
thread_id
),
num_rows
%
num_threads
);
const
int64_t
end_i
=
(
thread_id
+
1
)
*
(
num_rows
/
num_threads
)
+
std
::
min
(
static_cast
<
int64_t
>
(
thread_id
+
1
),
num_rows
%
num_threads
);
assert
(
thread_id
+
1
<
num_threads
||
end_i
==
num_rows
);
const
int64_t
num_local
=
end_i
-
start_i
;
std
::
unique_ptr
<
int64_t
[]
>
local_prefix
(
new
int64_t
[
num_local
+
1
]);
local_prefix
[
0
]
=
0
;
for
(
int64_t
i
=
start_i
;
i
<
end_i
;
++
i
)
{
// build prefix-sum
const
int64_t
local_i
=
i
-
start_i
;
const
IdxType
rid
=
rows_data
[
i
];
if
(
map_seed_nodes
)
seed_mapping_data
[
rid
]
=
i
;
IdxType
len
=
num_picks_fn
(
rid
,
indptr
[
rid
],
indptr
[
rid
+
1
]
-
indptr
[
rid
],
indices
,
data
);
local_prefix
[
local_i
+
1
]
=
local_prefix
[
local_i
]
+
len
;
}
global_prefix
[
thread_id
+
1
]
=
local_prefix
[
num_local
];
#pragma omp barrier
#pragma omp master
{
for
(
int
t
=
0
;
t
<
num_threads
;
++
t
)
{
global_prefix
[
t
+
1
]
+=
global_prefix
[
t
];
}
picked_col
=
IdArray
::
Empty
({
global_prefix
[
num_threads
]},
idtype
,
ctx
);
picked_idx
=
IdArray
::
Empty
({
global_prefix
[
num_threads
]},
idtype
,
ctx
);
picked_coo_rows
=
IdArray
::
Empty
({
global_prefix
[
num_threads
]},
idtype
,
ctx
);
}
#pragma omp barrier
IdxType
*
picked_cdata
=
picked_col
.
Ptr
<
IdxType
>
();
IdxType
*
picked_idata
=
picked_idx
.
Ptr
<
IdxType
>
();
IdxType
*
picked_rows
=
picked_coo_rows
.
Ptr
<
IdxType
>
();
const
IdxType
thread_offset
=
global_prefix
[
thread_id
];
for
(
int64_t
i
=
start_i
;
i
<
end_i
;
++
i
)
{
const
IdxType
rid
=
rows_data
[
i
];
const
int64_t
local_i
=
i
-
start_i
;
block_csr_indptr_data
[
i
]
=
local_prefix
[
local_i
]
+
thread_offset
;
const
IdxType
off
=
indptr
[
rid
];
const
IdxType
len
=
indptr
[
rid
+
1
]
-
off
;
if
(
len
==
0
)
continue
;
const
int64_t
row_offset
=
local_prefix
[
local_i
]
+
thread_offset
;
const
int64_t
num_picks
=
local_prefix
[
local_i
+
1
]
+
thread_offset
-
row_offset
;
pick_fn
(
rid
,
off
,
len
,
num_picks
,
indices
,
data
,
picked_idata
+
row_offset
);
for
(
int64_t
j
=
0
;
j
<
num_picks
;
++
j
)
{
const
IdxType
picked
=
picked_idata
[
row_offset
+
j
];
picked_cdata
[
row_offset
+
j
]
=
indices
[
picked
];
picked_idata
[
row_offset
+
j
]
=
data
?
data
[
picked
]
:
picked
;
picked_rows
[
row_offset
+
j
]
=
i
;
}
}
}
block_csr_indptr_data
[
num_rows
]
=
global_prefix
.
back
();
const
IdxType
num_cols
=
picked_col
->
shape
[
0
];
if
(
map_seed_nodes
)
{
(
*
new_seed_nodes
).
resize
(
num_rows
);
memcpy
((
*
new_seed_nodes
).
data
(),
rows_data
,
sizeof
(
IdxType
)
*
num_rows
);
}
return
std
::
make_pair
(
CSRMatrix
(
num_rows
,
num_cols
,
block_csr_indptr
,
picked_col
,
picked_idx
),
picked_coo_rows
);
}
// Template for picking non-zero values row-wise. The implementation utilizes
// OpenMP parallelization on rows because each row performs computation
// independently.
...
...
src/array/cpu/rowwise_sampling.cc
View file @
4135b1bd
...
...
@@ -225,6 +225,74 @@ template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, uint8_t>(
template
COOMatrix
CSRRowWiseSampling
<
kDGLCPU
,
int64_t
,
uint8_t
>(
CSRMatrix
,
IdArray
,
int64_t
,
NDArray
,
bool
);
template
<
DGLDeviceType
XPU
,
typename
IdxType
,
typename
DType
,
bool
map_seed_nodes
>
std
::
pair
<
CSRMatrix
,
IdArray
>
CSRRowWiseSamplingFused
(
CSRMatrix
mat
,
IdArray
rows
,
IdArray
seed_mapping
,
std
::
vector
<
IdxType
>*
new_seed_nodes
,
int64_t
num_samples
,
NDArray
prob_or_mask
,
bool
replace
)
{
// If num_samples is -1, select all neighbors without replacement.
replace
=
(
replace
&&
num_samples
!=
-
1
);
CHECK
(
prob_or_mask
.
defined
());
auto
num_picks_fn
=
GetSamplingNumPicksFn
<
IdxType
,
DType
>
(
num_samples
,
prob_or_mask
,
replace
);
auto
pick_fn
=
GetSamplingPickFn
<
IdxType
,
DType
>
(
num_samples
,
prob_or_mask
,
replace
);
return
CSRRowWisePickFused
<
IdxType
,
map_seed_nodes
>
(
mat
,
rows
,
seed_mapping
,
new_seed_nodes
,
num_samples
,
replace
,
pick_fn
,
num_picks_fn
);
}
template
std
::
pair
<
CSRMatrix
,
IdArray
>
CSRRowWiseSamplingFused
<
kDGLCPU
,
int32_t
,
float
,
true
>
(
CSRMatrix
,
IdArray
,
IdArray
,
std
::
vector
<
int32_t
>*
,
int64_t
,
NDArray
,
bool
);
template
std
::
pair
<
CSRMatrix
,
IdArray
>
CSRRowWiseSamplingFused
<
kDGLCPU
,
int64_t
,
float
,
true
>
(
CSRMatrix
,
IdArray
,
IdArray
,
std
::
vector
<
int64_t
>*
,
int64_t
,
NDArray
,
bool
);
template
std
::
pair
<
CSRMatrix
,
IdArray
>
CSRRowWiseSamplingFused
<
kDGLCPU
,
int32_t
,
double
,
true
>
(
CSRMatrix
,
IdArray
,
IdArray
,
std
::
vector
<
int32_t
>*
,
int64_t
,
NDArray
,
bool
);
template
std
::
pair
<
CSRMatrix
,
IdArray
>
CSRRowWiseSamplingFused
<
kDGLCPU
,
int64_t
,
double
,
true
>
(
CSRMatrix
,
IdArray
,
IdArray
,
std
::
vector
<
int64_t
>*
,
int64_t
,
NDArray
,
bool
);
template
std
::
pair
<
CSRMatrix
,
IdArray
>
CSRRowWiseSamplingFused
<
kDGLCPU
,
int32_t
,
int8_t
,
true
>
(
CSRMatrix
,
IdArray
,
IdArray
,
std
::
vector
<
int32_t
>*
,
int64_t
,
NDArray
,
bool
);
template
std
::
pair
<
CSRMatrix
,
IdArray
>
CSRRowWiseSamplingFused
<
kDGLCPU
,
int64_t
,
int8_t
,
true
>
(
CSRMatrix
,
IdArray
,
IdArray
,
std
::
vector
<
int64_t
>*
,
int64_t
,
NDArray
,
bool
);
template
std
::
pair
<
CSRMatrix
,
IdArray
>
CSRRowWiseSamplingFused
<
kDGLCPU
,
int32_t
,
uint8_t
,
true
>
(
CSRMatrix
,
IdArray
,
IdArray
,
std
::
vector
<
int32_t
>*
,
int64_t
,
NDArray
,
bool
);
template
std
::
pair
<
CSRMatrix
,
IdArray
>
CSRRowWiseSamplingFused
<
kDGLCPU
,
int64_t
,
uint8_t
,
true
>
(
CSRMatrix
,
IdArray
,
IdArray
,
std
::
vector
<
int64_t
>*
,
int64_t
,
NDArray
,
bool
);
template
std
::
pair
<
CSRMatrix
,
IdArray
>
CSRRowWiseSamplingFused
<
kDGLCPU
,
int32_t
,
float
,
false
>
(
CSRMatrix
,
IdArray
,
IdArray
,
std
::
vector
<
int32_t
>*
,
int64_t
,
NDArray
,
bool
);
template
std
::
pair
<
CSRMatrix
,
IdArray
>
CSRRowWiseSamplingFused
<
kDGLCPU
,
int64_t
,
float
,
false
>
(
CSRMatrix
,
IdArray
,
IdArray
,
std
::
vector
<
int64_t
>*
,
int64_t
,
NDArray
,
bool
);
template
std
::
pair
<
CSRMatrix
,
IdArray
>
CSRRowWiseSamplingFused
<
kDGLCPU
,
int32_t
,
double
,
false
>
(
CSRMatrix
,
IdArray
,
IdArray
,
std
::
vector
<
int32_t
>*
,
int64_t
,
NDArray
,
bool
);
template
std
::
pair
<
CSRMatrix
,
IdArray
>
CSRRowWiseSamplingFused
<
kDGLCPU
,
int64_t
,
double
,
false
>
(
CSRMatrix
,
IdArray
,
IdArray
,
std
::
vector
<
int64_t
>*
,
int64_t
,
NDArray
,
bool
);
template
std
::
pair
<
CSRMatrix
,
IdArray
>
CSRRowWiseSamplingFused
<
kDGLCPU
,
int32_t
,
int8_t
,
false
>
(
CSRMatrix
,
IdArray
,
IdArray
,
std
::
vector
<
int32_t
>*
,
int64_t
,
NDArray
,
bool
);
template
std
::
pair
<
CSRMatrix
,
IdArray
>
CSRRowWiseSamplingFused
<
kDGLCPU
,
int64_t
,
int8_t
,
false
>
(
CSRMatrix
,
IdArray
,
IdArray
,
std
::
vector
<
int64_t
>*
,
int64_t
,
NDArray
,
bool
);
template
std
::
pair
<
CSRMatrix
,
IdArray
>
CSRRowWiseSamplingFused
<
kDGLCPU
,
int32_t
,
uint8_t
,
false
>
(
CSRMatrix
,
IdArray
,
IdArray
,
std
::
vector
<
int32_t
>*
,
int64_t
,
NDArray
,
bool
);
template
std
::
pair
<
CSRMatrix
,
IdArray
>
CSRRowWiseSamplingFused
<
kDGLCPU
,
int64_t
,
uint8_t
,
false
>
(
CSRMatrix
,
IdArray
,
IdArray
,
std
::
vector
<
int64_t
>*
,
int64_t
,
NDArray
,
bool
);
template
<
DGLDeviceType
XPU
,
typename
IdxType
,
typename
DType
>
COOMatrix
CSRRowWisePerEtypeSampling
(
CSRMatrix
mat
,
IdArray
rows
,
const
std
::
vector
<
int64_t
>&
eid2etype_offset
,
...
...
@@ -283,6 +351,33 @@ template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int32_t>(
template
COOMatrix
CSRRowWiseSamplingUniform
<
kDGLCPU
,
int64_t
>(
CSRMatrix
,
IdArray
,
int64_t
,
bool
);
template
<
DGLDeviceType
XPU
,
typename
IdxType
,
bool
map_seed_nodes
>
std
::
pair
<
CSRMatrix
,
IdArray
>
CSRRowWiseSamplingUniformFused
(
CSRMatrix
mat
,
IdArray
rows
,
IdArray
seed_mapping
,
std
::
vector
<
IdxType
>*
new_seed_nodes
,
int64_t
num_samples
,
bool
replace
)
{
// If num_samples is -1, select all neighbors without replacement.
replace
=
(
replace
&&
num_samples
!=
-
1
);
auto
num_picks_fn
=
GetSamplingUniformNumPicksFn
<
IdxType
>
(
num_samples
,
replace
);
auto
pick_fn
=
GetSamplingUniformPickFn
<
IdxType
>
(
num_samples
,
replace
);
return
CSRRowWisePickFused
<
IdxType
,
map_seed_nodes
>
(
mat
,
rows
,
seed_mapping
,
new_seed_nodes
,
num_samples
,
replace
,
pick_fn
,
num_picks_fn
);
}
template
std
::
pair
<
CSRMatrix
,
IdArray
>
CSRRowWiseSamplingUniformFused
<
kDGLCPU
,
int32_t
,
true
>
(
CSRMatrix
,
IdArray
,
IdArray
,
std
::
vector
<
int32_t
>*
,
int64_t
,
bool
);
template
std
::
pair
<
CSRMatrix
,
IdArray
>
CSRRowWiseSamplingUniformFused
<
kDGLCPU
,
int64_t
,
true
>
(
CSRMatrix
,
IdArray
,
IdArray
,
std
::
vector
<
int64_t
>*
,
int64_t
,
bool
);
template
std
::
pair
<
CSRMatrix
,
IdArray
>
CSRRowWiseSamplingUniformFused
<
kDGLCPU
,
int32_t
,
false
>
(
CSRMatrix
,
IdArray
,
IdArray
,
std
::
vector
<
int32_t
>*
,
int64_t
,
bool
);
template
std
::
pair
<
CSRMatrix
,
IdArray
>
CSRRowWiseSamplingUniformFused
<
kDGLCPU
,
int64_t
,
false
>
(
CSRMatrix
,
IdArray
,
IdArray
,
std
::
vector
<
int64_t
>*
,
int64_t
,
bool
);
template
<
DGLDeviceType
XPU
,
typename
IdxType
>
COOMatrix
CSRRowWisePerEtypeSamplingUniform
(
CSRMatrix
mat
,
IdArray
rows
,
const
std
::
vector
<
int64_t
>&
eid2etype_offset
,
...
...
src/graph/sampling/neighbor/neighbor.cc
View file @
4135b1bd
...
...
@@ -6,13 +6,16 @@
#include <dgl/array.h>
#include <dgl/aten/macro.h>
#include <dgl/immutable_graph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/parallel_for.h>
#include <dgl/sampling/neighbor.h>
#include <tuple>
#include <utility>
#include "../../../array/cpu/concurrent_id_hash_map.h"
#include "../../../c_api_common.h"
#include "../../unit_graph.h"
...
...
@@ -22,6 +25,76 @@ using namespace dgl::aten;
namespace
dgl
{
namespace
sampling
{
template
<
typename
IdType
>
void
ExcludeCertainEdgesFused
(
std
::
vector
<
CSRMatrix
>*
sampled_graphs
,
std
::
vector
<
IdArray
>*
induced_edges
,
std
::
vector
<
IdArray
>*
sampled_coo_rows
,
const
std
::
vector
<
IdArray
>&
exclude_edges
,
std
::
vector
<
FloatArray
>*
weights
=
nullptr
)
{
int
etypes
=
(
*
sampled_graphs
).
size
();
std
::
vector
<
IdArray
>
remain_induced_edges
(
etypes
);
std
::
vector
<
IdArray
>
remain_indptrs
(
etypes
);
std
::
vector
<
IdArray
>
remain_indices
(
etypes
);
std
::
vector
<
IdArray
>
remain_coo_rows
(
etypes
);
std
::
vector
<
FloatArray
>
remain_weights
(
etypes
);
for
(
int
etype
=
0
;
etype
<
etypes
;
++
etype
)
{
if
(
exclude_edges
[
etype
].
GetSize
()
==
0
||
(
*
sampled_graphs
)[
etype
].
num_rows
==
0
)
{
remain_induced_edges
[
etype
]
=
(
*
induced_edges
)[
etype
];
if
(
weights
)
remain_weights
[
etype
]
=
(
*
weights
)[
etype
];
continue
;
}
const
auto
dtype
=
weights
&&
(
*
weights
)[
etype
]
->
shape
[
0
]
?
(
*
weights
)[
etype
]
->
dtype
:
DGLDataType
{
kDGLFloat
,
8
*
sizeof
(
float
),
1
};
ATEN_FLOAT_TYPE_SWITCH
(
dtype
,
FloatType
,
"weights"
,
{
IdType
*
indptr
=
(
*
sampled_graphs
)[
etype
].
indptr
.
Ptr
<
IdType
>
();
IdType
*
indices
=
(
*
sampled_graphs
)[
etype
].
indices
.
Ptr
<
IdType
>
();
IdType
*
coo_rows
=
(
*
sampled_coo_rows
)[
etype
].
Ptr
<
IdType
>
();
IdType
*
induced_edges_data
=
(
*
induced_edges
)[
etype
].
Ptr
<
IdType
>
();
FloatType
*
weights_data
=
weights
&&
(
*
weights
)[
etype
]
->
shape
[
0
]
?
(
*
weights
)[
etype
].
Ptr
<
FloatType
>
()
:
nullptr
;
const
IdType
exclude_edges_len
=
exclude_edges
[
etype
]
->
shape
[
0
];
std
::
sort
(
exclude_edges
[
etype
].
Ptr
<
IdType
>
(),
exclude_edges
[
etype
].
Ptr
<
IdType
>
()
+
exclude_edges_len
);
const
IdType
*
exclude_edges_data
=
exclude_edges
[
etype
].
Ptr
<
IdType
>
();
IdType
outIndices
=
0
;
for
(
IdType
row
=
0
;
row
<
(
*
sampled_graphs
)[
etype
].
indptr
->
shape
[
0
]
-
1
;
++
row
)
{
auto
tmp_row
=
indptr
[
row
];
if
(
outIndices
!=
indptr
[
row
])
indptr
[
row
]
=
outIndices
;
for
(
IdType
col
=
tmp_row
;
col
<
indptr
[
row
+
1
];
++
col
)
{
if
(
!
std
::
binary_search
(
exclude_edges_data
,
exclude_edges_data
+
exclude_edges_len
,
induced_edges_data
[
col
]))
{
indices
[
outIndices
]
=
indices
[
col
];
induced_edges_data
[
outIndices
]
=
induced_edges_data
[
col
];
coo_rows
[
outIndices
]
=
coo_rows
[
col
];
if
(
weights_data
)
weights_data
[
outIndices
]
=
weights_data
[
col
];
++
outIndices
;
}
}
}
indptr
[(
*
sampled_graphs
)[
etype
].
indptr
->
shape
[
0
]
-
1
]
=
outIndices
;
remain_induced_edges
[
etype
]
=
aten
::
IndexSelect
((
*
induced_edges
)[
etype
],
0
,
outIndices
);
remain_weights
[
etype
]
=
weights_data
?
aten
::
IndexSelect
((
*
weights
)[
etype
],
0
,
outIndices
)
:
NullArray
();
remain_indices
[
etype
]
=
aten
::
IndexSelect
((
*
sampled_graphs
)[
etype
].
indices
,
0
,
outIndices
);
(
*
sampled_coo_rows
)[
etype
]
=
aten
::
IndexSelect
((
*
sampled_coo_rows
)[
etype
],
0
,
outIndices
);
(
*
sampled_graphs
)[
etype
]
=
CSRMatrix
(
(
*
sampled_graphs
)[
etype
].
num_rows
,
outIndices
,
(
*
sampled_graphs
)[
etype
].
indptr
,
remain_indices
[
etype
],
remain_induced_edges
[
etype
]);
});
}
}
std
::
pair
<
HeteroSubgraph
,
std
::
vector
<
FloatArray
>>
ExcludeCertainEdges
(
const
HeteroSubgraph
&
sg
,
const
std
::
vector
<
IdArray
>&
exclude_edges
,
const
std
::
vector
<
FloatArray
>*
weights
=
nullptr
)
{
...
...
@@ -266,6 +339,242 @@ HeteroSubgraph SampleNeighbors(
return
ret
;
}
template
<
typename
IdType
>
std
::
tuple
<
HeteroGraphPtr
,
std
::
vector
<
IdArray
>
,
std
::
vector
<
IdArray
>>
SampleNeighborsFused
(
const
HeteroGraphPtr
hg
,
const
std
::
vector
<
IdArray
>&
nodes
,
const
std
::
vector
<
IdArray
>&
mapping
,
const
std
::
vector
<
int64_t
>&
fanouts
,
EdgeDir
dir
,
const
std
::
vector
<
NDArray
>&
prob_or_mask
,
const
std
::
vector
<
IdArray
>&
exclude_edges
,
bool
replace
)
{
CHECK_EQ
(
nodes
.
size
(),
hg
->
NumVertexTypes
())
<<
"Number of node ID tensors must match the number of node types."
;
CHECK_EQ
(
fanouts
.
size
(),
hg
->
NumEdgeTypes
())
<<
"Number of fanout values must match the number of edge types."
;
CHECK_EQ
(
prob_or_mask
.
size
(),
hg
->
NumEdgeTypes
())
<<
"Number of probability tensors must match the number of edge types."
;
DGLContext
ctx
=
aten
::
GetContextOf
(
nodes
);
std
::
vector
<
CSRMatrix
>
sampled_graphs
;
std
::
vector
<
IdArray
>
sampled_coo_rows
;
std
::
vector
<
IdArray
>
induced_edges
;
std
::
vector
<
IdArray
>
induced_vertices
;
std
::
vector
<
int64_t
>
num_nodes_per_type
;
std
::
vector
<
std
::
vector
<
IdType
>>
new_nodes_vec
(
hg
->
NumVertexTypes
());
std
::
vector
<
int
>
seed_nodes_mapped
(
hg
->
NumVertexTypes
(),
0
);
for
(
dgl_type_t
etype
=
0
;
etype
<
hg
->
NumEdgeTypes
();
++
etype
)
{
auto
pair
=
hg
->
meta_graph
()
->
FindEdge
(
etype
);
const
dgl_type_t
src_vtype
=
pair
.
first
;
const
dgl_type_t
dst_vtype
=
pair
.
second
;
const
dgl_type_t
rhs_node_type
=
(
dir
==
EdgeDir
::
kOut
)
?
src_vtype
:
dst_vtype
;
const
IdArray
nodes_ntype
=
nodes
[
rhs_node_type
];
const
int64_t
num_nodes
=
nodes_ntype
->
shape
[
0
];
if
(
num_nodes
==
0
||
fanouts
[
etype
]
==
0
)
{
// Nothing to sample for this etype, create a placeholder
sampled_graphs
.
push_back
(
CSRMatrix
());
sampled_coo_rows
.
push_back
(
IdArray
());
induced_edges
.
push_back
(
aten
::
NullArray
(
hg
->
DataType
(),
ctx
));
}
else
{
bool
map_seed_nodes
=
!
seed_nodes_mapped
[
rhs_node_type
];
// sample from one relation graph
std
::
pair
<
CSRMatrix
,
IdArray
>
sampled_graph
;
auto
sampling_fn
=
map_seed_nodes
?
aten
::
CSRRowWiseSamplingFused
<
IdType
,
true
>
:
aten
::
CSRRowWiseSamplingFused
<
IdType
,
false
>
;
auto
req_fmt
=
(
dir
==
EdgeDir
::
kOut
)
?
CSR_CODE
:
CSC_CODE
;
auto
avail_fmt
=
hg
->
SelectFormat
(
etype
,
req_fmt
);
switch
(
avail_fmt
)
{
case
SparseFormat
::
kCSR
:
CHECK
(
dir
==
EdgeDir
::
kOut
)
<<
"Cannot sample out edges on CSC matrix."
;
// In heterographs nodes of two diffrent types can be connected
// therefore two diffrent mappings and node vectors are needed
sampled_graph
=
sampling_fn
(
hg
->
GetCSRMatrix
(
etype
),
nodes_ntype
,
mapping
[
src_vtype
],
&
new_nodes_vec
[
src_vtype
],
fanouts
[
etype
],
prob_or_mask
[
etype
],
replace
);
break
;
case
SparseFormat
::
kCSC
:
CHECK
(
dir
==
EdgeDir
::
kIn
)
<<
"Cannot sample in edges on CSR matrix."
;
sampled_graph
=
sampling_fn
(
hg
->
GetCSCMatrix
(
etype
),
nodes_ntype
,
mapping
[
dst_vtype
],
&
new_nodes_vec
[
dst_vtype
],
fanouts
[
etype
],
prob_or_mask
[
etype
],
replace
);
break
;
default:
LOG
(
FATAL
)
<<
"Unsupported sparse format."
;
}
seed_nodes_mapped
[
rhs_node_type
]
++
;
sampled_graphs
.
push_back
(
sampled_graph
.
first
);
if
(
sampled_graph
.
first
.
data
.
defined
())
induced_edges
.
push_back
(
sampled_graph
.
first
.
data
);
else
induced_edges
.
push_back
(
aten
::
NullArray
(
DGLDataType
{
kDGLInt
,
sizeof
(
IdType
)
*
8
,
1
},
ctx
));
sampled_coo_rows
.
push_back
(
sampled_graph
.
second
);
}
}
if
(
!
exclude_edges
.
empty
())
{
ExcludeCertainEdgesFused
<
IdType
>
(
&
sampled_graphs
,
&
induced_edges
,
&
sampled_coo_rows
,
exclude_edges
);
for
(
size_t
i
=
0
;
i
<
hg
->
NumEdgeTypes
();
i
++
)
{
if
(
sampled_graphs
[
i
].
data
.
defined
())
induced_edges
[
i
]
=
std
::
move
(
sampled_graphs
[
i
].
data
);
else
induced_edges
[
i
]
=
aten
::
NullArray
(
DGLDataType
{
kDGLInt
,
sizeof
(
IdType
)
*
8
,
1
},
ctx
);
}
}
// map indices
for
(
dgl_type_t
etype
=
0
;
etype
<
hg
->
NumEdgeTypes
();
++
etype
)
{
auto
pair
=
hg
->
meta_graph
()
->
FindEdge
(
etype
);
const
dgl_type_t
src_vtype
=
pair
.
first
;
const
dgl_type_t
dst_vtype
=
pair
.
second
;
const
dgl_type_t
lhs_node_type
=
(
dir
==
EdgeDir
::
kIn
)
?
src_vtype
:
dst_vtype
;
if
(
sampled_graphs
[
etype
].
num_cols
!=
0
)
{
auto
num_cols
=
sampled_graphs
[
etype
].
num_cols
;
int
num_threads_col
=
runtime
::
compute_num_threads
(
0
,
num_cols
,
1
);
std
::
vector
<
IdType
>
global_prefix_col
(
num_threads_col
+
1
,
0
);
std
::
vector
<
std
::
vector
<
IdType
>>
src_nodes_local
(
num_threads_col
);
IdType
*
mapping_data_dst
=
mapping
[
lhs_node_type
].
Ptr
<
IdType
>
();
IdType
*
cdata
=
sampled_graphs
[
etype
].
indices
.
Ptr
<
IdType
>
();
#pragma omp parallel num_threads(num_threads_col)
{
const
int
thread_id
=
omp_get_thread_num
();
num_threads_col
=
omp_get_num_threads
();
const
int64_t
start_i
=
thread_id
*
(
num_cols
/
num_threads_col
)
+
std
::
min
(
static_cast
<
int64_t
>
(
thread_id
),
num_cols
%
num_threads_col
);
const
int64_t
end_i
=
(
thread_id
+
1
)
*
(
num_cols
/
num_threads_col
)
+
std
::
min
(
static_cast
<
int64_t
>
(
thread_id
+
1
),
num_cols
%
num_threads_col
);
assert
(
thread_id
+
1
<
num_threads_col
||
end_i
==
num_cols
);
for
(
int64_t
i
=
start_i
;
i
<
end_i
;
++
i
)
{
int64_t
picked_idx
=
cdata
[
i
];
bool
spot_claimed
=
BoolCompareAndSwap
<
IdType
>
(
&
mapping_data_dst
[
picked_idx
]);
if
(
spot_claimed
)
src_nodes_local
[
thread_id
].
push_back
(
picked_idx
);
}
global_prefix_col
[
thread_id
+
1
]
=
src_nodes_local
[
thread_id
].
size
();
#pragma omp barrier
#pragma omp master
{
global_prefix_col
[
0
]
=
new_nodes_vec
[
lhs_node_type
].
size
();
for
(
int
t
=
0
;
t
<
num_threads_col
;
++
t
)
{
global_prefix_col
[
t
+
1
]
+=
global_prefix_col
[
t
];
}
}
#pragma omp barrier
int64_t
mapping_shift
=
global_prefix_col
[
thread_id
];
for
(
size_t
i
=
0
;
i
<
src_nodes_local
[
thread_id
].
size
();
++
i
)
mapping_data_dst
[
src_nodes_local
[
thread_id
][
i
]]
=
mapping_shift
+
i
;
#pragma omp barrier
for
(
int64_t
i
=
start_i
;
i
<
end_i
;
++
i
)
{
IdType
picked_idx
=
cdata
[
i
];
IdType
mapped_idx
=
mapping_data_dst
[
picked_idx
];
cdata
[
i
]
=
mapped_idx
;
}
}
IdType
offset
=
new_nodes_vec
[
lhs_node_type
].
size
();
new_nodes_vec
[
lhs_node_type
].
resize
(
global_prefix_col
.
back
());
for
(
int
thread_id
=
0
;
thread_id
<
num_threads_col
;
++
thread_id
)
{
memcpy
(
new_nodes_vec
[
lhs_node_type
].
data
()
+
offset
,
&
src_nodes_local
[
thread_id
][
0
],
src_nodes_local
[
thread_id
].
size
()
*
sizeof
(
IdType
));
offset
+=
src_nodes_local
[
thread_id
].
size
();
}
}
}
// counting how many nodes of each ntype were sampled
num_nodes_per_type
.
resize
(
2
*
hg
->
NumVertexTypes
());
for
(
size_t
i
=
0
;
i
<
hg
->
NumVertexTypes
();
i
++
)
{
num_nodes_per_type
[
i
]
=
new_nodes_vec
[
i
].
size
();
num_nodes_per_type
[
hg
->
NumVertexTypes
()
+
i
]
=
nodes
[
i
]
->
shape
[
0
];
induced_vertices
.
push_back
(
VecToIdArray
(
new_nodes_vec
[
i
],
sizeof
(
IdType
)
*
8
));
}
std
::
vector
<
HeteroGraphPtr
>
subrels
(
hg
->
NumEdgeTypes
());
for
(
dgl_type_t
etype
=
0
;
etype
<
hg
->
NumEdgeTypes
();
++
etype
)
{
auto
pair
=
hg
->
meta_graph
()
->
FindEdge
(
etype
);
const
dgl_type_t
src_vtype
=
pair
.
first
;
const
dgl_type_t
dst_vtype
=
pair
.
second
;
if
(
sampled_graphs
[
etype
].
num_rows
==
0
)
{
subrels
[
etype
]
=
UnitGraph
::
Empty
(
2
,
new_nodes_vec
[
src_vtype
].
size
(),
nodes
[
dst_vtype
]
->
shape
[
0
],
hg
->
DataType
(),
ctx
);
}
else
{
CSRMatrix
graph
=
sampled_graphs
[
etype
];
if
(
dir
==
EdgeDir
::
kOut
)
{
subrels
[
etype
]
=
UnitGraph
::
CreateFromCSRAndCOO
(
2
,
CSRMatrix
(
nodes
[
src_vtype
]
->
shape
[
0
],
new_nodes_vec
[
dst_vtype
].
size
(),
graph
.
indptr
,
graph
.
indices
,
Range
(
0
,
graph
.
indices
->
shape
[
0
],
graph
.
indices
->
dtype
.
bits
,
ctx
)),
COOMatrix
(
nodes
[
src_vtype
]
->
shape
[
0
],
new_nodes_vec
[
dst_vtype
].
size
(),
sampled_coo_rows
[
etype
],
graph
.
indices
),
ALL_CODE
);
}
else
{
subrels
[
etype
]
=
UnitGraph
::
CreateFromCSCAndCOO
(
2
,
CSRMatrix
(
nodes
[
dst_vtype
]
->
shape
[
0
],
new_nodes_vec
[
src_vtype
].
size
(),
graph
.
indptr
,
graph
.
indices
,
Range
(
0
,
graph
.
indices
->
shape
[
0
],
graph
.
indices
->
dtype
.
bits
,
ctx
)),
COOMatrix
(
new_nodes_vec
[
src_vtype
].
size
(),
nodes
[
dst_vtype
]
->
shape
[
0
],
graph
.
indices
,
sampled_coo_rows
[
etype
]),
ALL_CODE
);
}
}
}
HeteroSubgraph
ret
;
const
auto
meta_graph
=
hg
->
meta_graph
();
const
EdgeArray
etypes
=
meta_graph
->
Edges
(
"eid"
);
const
IdArray
new_dst
=
Add
(
etypes
.
dst
,
hg
->
NumVertexTypes
());
const
auto
new_meta_graph
=
ImmutableGraph
::
CreateFromCOO
(
hg
->
NumVertexTypes
()
*
2
,
etypes
.
src
,
new_dst
);
HeteroGraphPtr
new_graph
=
CreateHeteroGraph
(
new_meta_graph
,
subrels
,
num_nodes_per_type
);
return
std
::
make_tuple
(
new_graph
,
induced_edges
,
induced_vertices
);
}
template
std
::
tuple
<
HeteroGraphPtr
,
std
::
vector
<
IdArray
>,
std
::
vector
<
IdArray
>>
SampleNeighborsFused
<
int64_t
>
(
const
HeteroGraphPtr
,
const
std
::
vector
<
IdArray
>&
,
const
std
::
vector
<
IdArray
>&
,
const
std
::
vector
<
int64_t
>&
,
EdgeDir
,
const
std
::
vector
<
NDArray
>&
,
const
std
::
vector
<
IdArray
>&
,
bool
);
template
std
::
tuple
<
HeteroGraphPtr
,
std
::
vector
<
IdArray
>,
std
::
vector
<
IdArray
>>
SampleNeighborsFused
<
int32_t
>
(
const
HeteroGraphPtr
,
const
std
::
vector
<
IdArray
>&
,
const
std
::
vector
<
IdArray
>&
,
const
std
::
vector
<
int64_t
>&
,
EdgeDir
,
const
std
::
vector
<
NDArray
>&
,
const
std
::
vector
<
IdArray
>&
,
bool
);
HeteroSubgraph
SampleNeighborsEType
(
const
HeteroGraphPtr
hg
,
const
IdArray
nodes
,
const
std
::
vector
<
int64_t
>&
eid2etype_offset
,
...
...
@@ -568,6 +877,47 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighbors")
*
rv
=
HeteroSubgraphRef
(
subg
);
});
DGL_REGISTER_GLOBAL
(
"sampling.neighbor._CAPI_DGLSampleNeighborsFused"
)
.
set_body
([](
DGLArgs
args
,
DGLRetValue
*
rv
)
{
HeteroGraphRef
hg
=
args
[
0
];
const
auto
&
nodes
=
ListValueToVector
<
IdArray
>
(
args
[
1
]);
auto
mapping
=
ListValueToVector
<
IdArray
>
(
args
[
2
]);
IdArray
fanouts_array
=
args
[
3
];
const
auto
&
fanouts
=
fanouts_array
.
ToVector
<
int64_t
>
();
const
std
::
string
dir_str
=
args
[
4
];
const
auto
&
prob_or_mask
=
ListValueToVector
<
NDArray
>
(
args
[
5
]);
const
auto
&
exclude_edges
=
ListValueToVector
<
IdArray
>
(
args
[
6
]);
const
bool
replace
=
args
[
7
];
CHECK
(
dir_str
==
"in"
||
dir_str
==
"out"
)
<<
"Invalid edge direction. Must be
\"
in
\"
or
\"
out
\"
."
;
EdgeDir
dir
=
(
dir_str
==
"in"
)
?
EdgeDir
::
kIn
:
EdgeDir
::
kOut
;
HeteroGraphPtr
new_graph
;
std
::
vector
<
IdArray
>
induced_edges
;
std
::
vector
<
IdArray
>
induced_vertices
;
ATEN_ID_TYPE_SWITCH
(
hg
->
DataType
(),
IdType
,
{
std
::
tie
(
new_graph
,
induced_edges
,
induced_vertices
)
=
SampleNeighborsFused
<
IdType
>
(
hg
.
sptr
(),
nodes
,
mapping
,
fanouts
,
dir
,
prob_or_mask
,
exclude_edges
,
replace
);
});
List
<
Value
>
lhs_nodes_ref
;
for
(
IdArray
&
array
:
induced_vertices
)
lhs_nodes_ref
.
push_back
(
Value
(
MakeValue
(
array
)));
List
<
Value
>
induced_edges_ref
;
for
(
IdArray
&
array
:
induced_edges
)
induced_edges_ref
.
push_back
(
Value
(
MakeValue
(
array
)));
List
<
ObjectRef
>
ret
;
ret
.
push_back
(
HeteroGraphRef
(
new_graph
));
ret
.
push_back
(
lhs_nodes_ref
);
ret
.
push_back
(
induced_edges_ref
);
*
rv
=
ret
;
});
DGL_REGISTER_GLOBAL
(
"sampling.neighbor._CAPI_DGLSampleNeighborsTopk"
)
.
set_body
([](
DGLArgs
args
,
DGLRetValue
*
rv
)
{
HeteroGraphRef
hg
=
args
[
0
];
...
...
src/graph/unit_graph.cc
View file @
4135b1bd
...
...
@@ -1218,6 +1218,21 @@ HeteroGraphPtr UnitGraph::CreateFromCSR(
return
HeteroGraphPtr
(
new
UnitGraph
(
mg
,
nullptr
,
csr
,
nullptr
,
formats
));
}
HeteroGraphPtr
UnitGraph
::
CreateFromCSRAndCOO
(
int64_t
num_vtypes
,
const
aten
::
CSRMatrix
&
csr
,
const
aten
::
COOMatrix
&
coo
,
dgl_format_code_t
formats
)
{
CHECK
(
num_vtypes
==
1
||
num_vtypes
==
2
);
CHECK_EQ
(
coo
.
num_rows
,
csr
.
num_rows
);
CHECK_EQ
(
coo
.
num_cols
,
csr
.
num_cols
);
if
(
num_vtypes
==
1
)
{
CHECK_EQ
(
csr
.
num_rows
,
csr
.
num_cols
);
}
auto
mg
=
CreateUnitGraphMetaGraph
(
num_vtypes
);
CSRPtr
csrPtr
(
new
CSR
(
mg
,
csr
));
COOPtr
cooPtr
(
new
COO
(
mg
,
coo
));
return
HeteroGraphPtr
(
new
UnitGraph
(
mg
,
nullptr
,
csrPtr
,
cooPtr
,
formats
));
}
HeteroGraphPtr
UnitGraph
::
CreateFromCSC
(
int64_t
num_vtypes
,
int64_t
num_src
,
int64_t
num_dst
,
IdArray
indptr
,
IdArray
indices
,
IdArray
edge_ids
,
dgl_format_code_t
formats
)
{
...
...
@@ -1237,6 +1252,21 @@ HeteroGraphPtr UnitGraph::CreateFromCSC(
return
HeteroGraphPtr
(
new
UnitGraph
(
mg
,
csc
,
nullptr
,
nullptr
,
formats
));
}
HeteroGraphPtr
UnitGraph
::
CreateFromCSCAndCOO
(
int64_t
num_vtypes
,
const
aten
::
CSRMatrix
&
csc
,
const
aten
::
COOMatrix
&
coo
,
dgl_format_code_t
formats
)
{
CHECK
(
num_vtypes
==
1
||
num_vtypes
==
2
);
CHECK_EQ
(
coo
.
num_rows
,
csc
.
num_cols
);
CHECK_EQ
(
coo
.
num_cols
,
csc
.
num_rows
);
if
(
num_vtypes
==
1
)
{
CHECK_EQ
(
csc
.
num_rows
,
csc
.
num_cols
);
}
auto
mg
=
CreateUnitGraphMetaGraph
(
num_vtypes
);
CSRPtr
cscPtr
(
new
CSR
(
mg
,
csc
));
COOPtr
cooPtr
(
new
COO
(
mg
,
coo
));
return
HeteroGraphPtr
(
new
UnitGraph
(
mg
,
cscPtr
,
nullptr
,
cooPtr
,
formats
));
}
HeteroGraphPtr
UnitGraph
::
AsNumBits
(
HeteroGraphPtr
g
,
uint8_t
bits
)
{
if
(
g
->
NumBits
()
==
bits
)
{
return
g
;
...
...
src/graph/unit_graph.h
View file @
4135b1bd
...
...
@@ -190,6 +190,12 @@ class UnitGraph : public BaseHeteroGraph {
int64_t
num_vtypes
,
const
aten
::
CSRMatrix
&
mat
,
dgl_format_code_t
formats
=
ALL_CODE
);
/** @brief Create a graph from (out) CSR and COO arrays, both representing the
* same graph */
static
HeteroGraphPtr
CreateFromCSRAndCOO
(
int64_t
num_vtypes
,
const
aten
::
CSRMatrix
&
csr
,
const
aten
::
COOMatrix
&
coo
,
dgl_format_code_t
formats
=
ALL_CODE
);
/** @brief Create a graph from (in) CSC arrays */
static
HeteroGraphPtr
CreateFromCSC
(
int64_t
num_vtypes
,
int64_t
num_src
,
int64_t
num_dst
,
IdArray
indptr
,
...
...
@@ -199,6 +205,12 @@ class UnitGraph : public BaseHeteroGraph {
int64_t
num_vtypes
,
const
aten
::
CSRMatrix
&
mat
,
dgl_format_code_t
formats
=
ALL_CODE
);
/** @brief Create a graph from (in) CSC and COO arrays, both representing the
* same graph */
static
HeteroGraphPtr
CreateFromCSCAndCOO
(
int64_t
num_vtypes
,
const
aten
::
CSRMatrix
&
csc
,
const
aten
::
COOMatrix
&
coo
,
dgl_format_code_t
formats
=
ALL_CODE
);
/** @brief Convert the graph to use the given number of bits for storage */
static
HeteroGraphPtr
AsNumBits
(
HeteroGraphPtr
g
,
uint8_t
bits
);
...
...
tests/python/common/sampling/test_sampling.py
View file @
4135b1bd
...
...
@@ -7,6 +7,11 @@ import dgl
import
numpy
as
np
import
pytest
sample_neighbors_fusing_mode
=
{
True
:
dgl
.
sampling
.
sample_neighbors_fused
,
False
:
dgl
.
sampling
.
sample_neighbors
,
}
def
check_random_walk
(
g
,
metapath
,
traces
,
ntypes
,
prob
=
None
,
trace_eids
=
None
):
traces
=
F
.
asnumpy
(
traces
)
...
...
@@ -555,15 +560,18 @@ def _gen_neighbor_topk_test_graph(hypersparse, reverse):
return
g
,
hg
def
_test_sample_neighbors
(
hypersparse
,
prob
):
def
_test_sample_neighbors
(
hypersparse
,
prob
,
fused
):
g
,
hg
=
_gen_neighbor_sampling_test_graph
(
hypersparse
,
False
)
def
_test1
(
p
,
replace
):
subg
=
dgl
.
sampling
.
sample_neighbors
(
subg
=
sample_neighbors
_fusing_mode
[
fused
]
(
g
,
[
0
,
1
],
-
1
,
prob
=
p
,
replace
=
replace
)
if
not
fused
:
assert
subg
.
num_nodes
()
==
g
.
num_nodes
()
u
,
v
=
subg
.
edges
()
if
fused
:
u
,
v
=
subg
.
srcdata
[
dgl
.
NID
][
u
],
subg
.
dstdata
[
dgl
.
NID
][
v
]
u_ans
,
v_ans
,
e_ans
=
g
.
in_edges
([
0
,
1
],
form
=
"all"
)
if
p
is
not
None
:
emask
=
F
.
gather_row
(
g
.
edata
[
p
],
e_ans
)
...
...
@@ -576,12 +584,17 @@ def _test_sample_neighbors(hypersparse, prob):
assert
uv
==
uv_ans
for
i
in
range
(
10
):
subg
=
dgl
.
sampling
.
sample_neighbors
(
subg
=
sample_neighbors
_fusing_mode
[
fused
]
(
g
,
[
0
,
1
],
2
,
prob
=
p
,
replace
=
replace
)
if
not
fused
:
assert
subg
.
num_nodes
()
==
g
.
num_nodes
()
assert
subg
.
num_edges
()
==
4
u
,
v
=
subg
.
edges
()
if
fused
:
u
,
v
=
subg
.
srcdata
[
dgl
.
NID
][
u
],
subg
.
dstdata
[
dgl
.
NID
][
v
]
assert
set
(
F
.
asnumpy
(
F
.
unique
(
v
)))
==
{
0
,
1
}
assert
F
.
array_equal
(
F
.
astype
(
g
.
has_edges_between
(
u
,
v
),
F
.
int64
),
...
...
@@ -600,11 +613,14 @@ def _test_sample_neighbors(hypersparse, prob):
_test1
(
prob
,
False
)
# w/o replacement, uniform
def
_test2
(
p
,
replace
):
# fanout > #neighbors
subg
=
dgl
.
sampling
.
sample_neighbors
(
subg
=
sample_neighbors
_fusing_mode
[
fused
]
(
g
,
[
0
,
2
],
-
1
,
prob
=
p
,
replace
=
replace
)
if
not
fused
:
assert
subg
.
num_nodes
()
==
g
.
num_nodes
()
u
,
v
=
subg
.
edges
()
if
fused
:
u
,
v
=
subg
.
srcdata
[
dgl
.
NID
][
u
],
subg
.
dstdata
[
dgl
.
NID
][
v
]
u_ans
,
v_ans
,
e_ans
=
g
.
in_edges
([
0
,
2
],
form
=
"all"
)
if
p
is
not
None
:
emask
=
F
.
gather_row
(
g
.
edata
[
p
],
e_ans
)
...
...
@@ -617,13 +633,16 @@ def _test_sample_neighbors(hypersparse, prob):
assert
uv
==
uv_ans
for
i
in
range
(
10
):
subg
=
dgl
.
sampling
.
sample_neighbors
(
subg
=
sample_neighbors
_fusing_mode
[
fused
]
(
g
,
[
0
,
2
],
2
,
prob
=
p
,
replace
=
replace
)
if
not
fused
:
assert
subg
.
num_nodes
()
==
g
.
num_nodes
()
num_edges
=
4
if
replace
else
3
assert
subg
.
num_edges
()
==
num_edges
u
,
v
=
subg
.
edges
()
if
fused
:
u
,
v
=
subg
.
srcdata
[
dgl
.
NID
][
u
],
subg
.
dstdata
[
dgl
.
NID
][
v
]
assert
set
(
F
.
asnumpy
(
F
.
unique
(
v
)))
==
{
0
,
2
}
assert
F
.
array_equal
(
F
.
astype
(
g
.
has_edges_between
(
u
,
v
),
F
.
int64
),
...
...
@@ -641,10 +660,13 @@ def _test_sample_neighbors(hypersparse, prob):
_test2
(
prob
,
False
)
# w/o replacement, uniform
def
_test3
(
p
,
replace
):
subg
=
dgl
.
sampling
.
sample_neighbors
(
subg
=
sample_neighbors
_fusing_mode
[
fused
]
(
hg
,
{
"user"
:
[
0
,
1
],
"game"
:
0
},
-
1
,
prob
=
p
,
replace
=
replace
)
if
not
fused
:
assert
len
(
subg
.
ntypes
)
==
3
assert
len
(
subg
.
srctypes
)
==
3
assert
len
(
subg
.
dsttypes
)
==
3
assert
len
(
subg
.
etypes
)
==
4
assert
subg
[
"follow"
].
num_edges
()
==
6
if
p
is
None
else
4
assert
subg
[
"play"
].
num_edges
()
==
1
...
...
@@ -652,10 +674,13 @@ def _test_sample_neighbors(hypersparse, prob):
assert
subg
[
"flips"
].
num_edges
()
==
0
for
i
in
range
(
10
):
subg
=
dgl
.
sampling
.
sample_neighbors
(
subg
=
sample_neighbors
_fusing_mode
[
fused
]
(
hg
,
{
"user"
:
[
0
,
1
],
"game"
:
0
},
2
,
prob
=
p
,
replace
=
replace
)
if
not
fused
:
assert
len
(
subg
.
ntypes
)
==
3
assert
len
(
subg
.
srctypes
)
==
3
assert
len
(
subg
.
dsttypes
)
==
3
assert
len
(
subg
.
etypes
)
==
4
assert
subg
[
"follow"
].
num_edges
()
==
4
assert
subg
[
"play"
].
num_edges
()
==
2
if
replace
else
1
...
...
@@ -667,13 +692,16 @@ def _test_sample_neighbors(hypersparse, prob):
# test different fanouts for different relations
for
i
in
range
(
10
):
subg
=
dgl
.
sampling
.
sample_neighbors
(
subg
=
sample_neighbors
_fusing_mode
[
fused
]
(
hg
,
{
"user"
:
[
0
,
1
],
"game"
:
0
,
"coin"
:
0
},
{
"follow"
:
1
,
"play"
:
2
,
"liked-by"
:
0
,
"flips"
:
-
1
},
replace
=
True
,
)
if
not
fused
:
assert
len
(
subg
.
ntypes
)
==
3
assert
len
(
subg
.
srctypes
)
==
3
assert
len
(
subg
.
dsttypes
)
==
3
assert
len
(
subg
.
etypes
)
==
4
assert
subg
[
"follow"
].
num_edges
()
==
2
assert
subg
[
"play"
].
num_edges
()
==
2
...
...
@@ -795,15 +823,19 @@ def _test_sample_labors(hypersparse, prob):
assert
subg
[
"flips"
].
num_edges
()
==
4
def
_test_sample_neighbors_outedge
(
hypersparse
):
def
_test_sample_neighbors_outedge
(
hypersparse
,
fused
):
g
,
hg
=
_gen_neighbor_sampling_test_graph
(
hypersparse
,
True
)
def
_test1
(
p
,
replace
):
subg
=
dgl
.
sampling
.
sample_neighbors
(
subg
=
sample_neighbors
_fusing_mode
[
fused
]
(
g
,
[
0
,
1
],
-
1
,
prob
=
p
,
replace
=
replace
,
edge_dir
=
"out"
)
if
not
fused
:
assert
subg
.
num_nodes
()
==
g
.
num_nodes
()
u
,
v
=
subg
.
edges
()
if
fused
:
u
,
v
=
subg
.
dstdata
[
dgl
.
NID
][
u
],
subg
.
srcdata
[
dgl
.
NID
][
v
]
u_ans
,
v_ans
,
e_ans
=
g
.
out_edges
([
0
,
1
],
form
=
"all"
)
if
p
is
not
None
:
emask
=
F
.
gather_row
(
g
.
edata
[
p
],
e_ans
)
...
...
@@ -816,12 +848,15 @@ def _test_sample_neighbors_outedge(hypersparse):
assert
uv
==
uv_ans
for
i
in
range
(
10
):
subg
=
dgl
.
sampling
.
sample_neighbors
(
subg
=
sample_neighbors
_fusing_mode
[
fused
]
(
g
,
[
0
,
1
],
2
,
prob
=
p
,
replace
=
replace
,
edge_dir
=
"out"
)
if
not
fused
:
assert
subg
.
num_nodes
()
==
g
.
num_nodes
()
assert
subg
.
num_edges
()
==
4
u
,
v
=
subg
.
edges
()
if
fused
:
u
,
v
=
subg
.
dstdata
[
dgl
.
NID
][
u
],
subg
.
srcdata
[
dgl
.
NID
][
v
]
assert
set
(
F
.
asnumpy
(
F
.
unique
(
u
)))
==
{
0
,
1
}
assert
F
.
array_equal
(
F
.
astype
(
g
.
has_edges_between
(
u
,
v
),
F
.
int64
),
...
...
@@ -842,11 +877,14 @@ def _test_sample_neighbors_outedge(hypersparse):
_test1
(
"prob"
,
False
)
# w/o replacement
def
_test2
(
p
,
replace
):
# fanout > #neighbors
subg
=
dgl
.
sampling
.
sample_neighbors
(
subg
=
sample_neighbors
_fusing_mode
[
fused
]
(
g
,
[
0
,
2
],
-
1
,
prob
=
p
,
replace
=
replace
,
edge_dir
=
"out"
)
if
not
fused
:
assert
subg
.
num_nodes
()
==
g
.
num_nodes
()
u
,
v
=
subg
.
edges
()
if
fused
:
u
,
v
=
subg
.
dstdata
[
dgl
.
NID
][
u
],
subg
.
srcdata
[
dgl
.
NID
][
v
]
u_ans
,
v_ans
,
e_ans
=
g
.
out_edges
([
0
,
2
],
form
=
"all"
)
if
p
is
not
None
:
emask
=
F
.
gather_row
(
g
.
edata
[
p
],
e_ans
)
...
...
@@ -859,13 +897,17 @@ def _test_sample_neighbors_outedge(hypersparse):
assert
uv
==
uv_ans
for
i
in
range
(
10
):
subg
=
dgl
.
sampling
.
sample_neighbors
(
subg
=
sample_neighbors
_fusing_mode
[
fused
]
(
g
,
[
0
,
2
],
2
,
prob
=
p
,
replace
=
replace
,
edge_dir
=
"out"
)
if
not
fused
:
assert
subg
.
num_nodes
()
==
g
.
num_nodes
()
num_edges
=
4
if
replace
else
3
assert
subg
.
num_edges
()
==
num_edges
u
,
v
=
subg
.
edges
()
if
fused
:
u
,
v
=
subg
.
dstdata
[
dgl
.
NID
][
u
],
subg
.
srcdata
[
dgl
.
NID
][
v
]
assert
set
(
F
.
asnumpy
(
F
.
unique
(
u
)))
==
{
0
,
2
}
assert
F
.
array_equal
(
F
.
astype
(
g
.
has_edges_between
(
u
,
v
),
F
.
int64
),
...
...
@@ -885,7 +927,7 @@ def _test_sample_neighbors_outedge(hypersparse):
_test2
(
"prob"
,
False
)
# w/o replacement
def
_test3
(
p
,
replace
):
subg
=
dgl
.
sampling
.
sample_neighbors
(
subg
=
sample_neighbors
_fusing_mode
[
fused
]
(
hg
,
{
"user"
:
[
0
,
1
],
"game"
:
0
},
-
1
,
...
...
@@ -893,7 +935,11 @@ def _test_sample_neighbors_outedge(hypersparse):
replace
=
replace
,
edge_dir
=
"out"
,
)
if
not
fused
:
assert
len
(
subg
.
ntypes
)
==
3
assert
len
(
subg
.
srctypes
)
==
3
assert
len
(
subg
.
dsttypes
)
==
3
assert
len
(
subg
.
etypes
)
==
4
assert
subg
[
"follow"
].
num_edges
()
==
6
if
p
is
None
else
4
assert
subg
[
"play"
].
num_edges
()
==
1
...
...
@@ -901,7 +947,7 @@ def _test_sample_neighbors_outedge(hypersparse):
assert
subg
[
"flips"
].
num_edges
()
==
0
for
i
in
range
(
10
):
subg
=
dgl
.
sampling
.
sample_neighbors
(
subg
=
sample_neighbors
_fusing_mode
[
fused
]
(
hg
,
{
"user"
:
[
0
,
1
],
"game"
:
0
},
2
,
...
...
@@ -909,7 +955,10 @@ def _test_sample_neighbors_outedge(hypersparse):
replace
=
replace
,
edge_dir
=
"out"
,
)
if
not
fused
:
assert
len
(
subg
.
ntypes
)
==
3
assert
len
(
subg
.
srctypes
)
==
3
assert
len
(
subg
.
dsttypes
)
==
3
assert
len
(
subg
.
etypes
)
==
4
assert
subg
[
"follow"
].
num_edges
()
==
4
assert
subg
[
"play"
].
num_edges
()
==
2
if
replace
else
1
...
...
@@ -1077,7 +1126,9 @@ def _test_sample_neighbors_topk_outedge(hypersparse):
def
test_sample_neighbors_noprob
():
_test_sample_neighbors
(
False
,
None
)
_test_sample_neighbors
(
False
,
None
,
False
)
if
F
.
_default_context_str
!=
"gpu"
and
F
.
backend_name
==
"pytorch"
:
_test_sample_neighbors
(
False
,
None
,
True
)
# _test_sample_neighbors(True)
...
...
@@ -1086,7 +1137,9 @@ def test_sample_labors_noprob():
def
test_sample_neighbors_prob
():
_test_sample_neighbors
(
False
,
"prob"
)
_test_sample_neighbors
(
False
,
"prob"
,
False
)
if
F
.
_default_context_str
!=
"gpu"
and
F
.
backend_name
==
"pytorch"
:
_test_sample_neighbors
(
False
,
"prob"
,
True
)
# _test_sample_neighbors(True)
...
...
@@ -1095,7 +1148,9 @@ def test_sample_labors_prob():
def
test_sample_neighbors_outedge
():
_test_sample_neighbors_outedge
(
False
)
_test_sample_neighbors_outedge
(
False
,
False
)
if
F
.
_default_context_str
!=
"gpu"
and
F
.
backend_name
==
"pytorch"
:
_test_sample_neighbors_outedge
(
False
,
True
)
# _test_sample_neighbors_outedge(True)
...
...
@@ -1107,7 +1162,9 @@ def test_sample_neighbors_outedge():
reason
=
"GPU sample neighbors with mask not implemented"
,
)
def
test_sample_neighbors_mask
():
_test_sample_neighbors
(
False
,
"mask"
)
_test_sample_neighbors
(
False
,
"mask"
,
False
)
if
F
.
_default_context_str
!=
"gpu"
and
F
.
backend_name
==
"pytorch"
:
_test_sample_neighbors
(
False
,
"mask"
,
True
)
@
unittest
.
skipIf
(
...
...
@@ -1128,21 +1185,26 @@ def test_sample_neighbors_topk_outedge():
# _test_sample_neighbors_topk_outedge(True)
def
test_sample_neighbors_with_0deg
():
@
pytest
.
mark
.
parametrize
(
"fused"
,
[
False
,
True
])
def
test_sample_neighbors_with_0deg
(
fused
):
if
fused
and
(
F
.
_default_context_str
==
"gpu"
or
F
.
backend_name
!=
"pytorch"
):
pytest
.
skip
(
"Fused sampling support CPU with backend PyTorch."
)
g
=
dgl
.
graph
(([],
[]),
num_nodes
=
5
).
to
(
F
.
ctx
())
sg
=
dgl
.
sampling
.
sample_neighbors
(
sg
=
sample_neighbors
_fusing_mode
[
fused
]
(
g
,
F
.
tensor
([
1
,
2
],
dtype
=
F
.
int64
),
2
,
edge_dir
=
"in"
,
replace
=
False
)
assert
sg
.
num_edges
()
==
0
sg
=
dgl
.
sampling
.
sample_neighbors
(
sg
=
sample_neighbors
_fusing_mode
[
fused
]
(
g
,
F
.
tensor
([
1
,
2
],
dtype
=
F
.
int64
),
2
,
edge_dir
=
"in"
,
replace
=
True
)
assert
sg
.
num_edges
()
==
0
sg
=
dgl
.
sampling
.
sample_neighbors
(
sg
=
sample_neighbors
_fusing_mode
[
fused
]
(
g
,
F
.
tensor
([
1
,
2
],
dtype
=
F
.
int64
),
2
,
edge_dir
=
"out"
,
replace
=
False
)
assert
sg
.
num_edges
()
==
0
sg
=
dgl
.
sampling
.
sample_neighbors
(
sg
=
sample_neighbors
_fusing_mode
[
fused
]
(
g
,
F
.
tensor
([
1
,
2
],
dtype
=
F
.
int64
),
2
,
edge_dir
=
"out"
,
replace
=
True
)
assert
sg
.
num_edges
()
==
0
...
...
@@ -1274,7 +1336,7 @@ def test_sample_neighbors_biased_homogeneous():
)
def
test_sample_neighbors_biased_bipartite
():
g
=
create_test_graph
(
100
,
30
,
True
)
num_dst
=
g
.
num
ber_of
_dst_nodes
()
num_dst
=
g
.
num_dst_nodes
()
bias
=
F
.
tensor
([
0
,
0.01
,
10
,
10
],
dtype
=
F
.
float32
)
def
check_num
(
nodes
,
tag
):
...
...
@@ -1492,7 +1554,12 @@ def test_sample_neighbors_etype_sorted_homogeneous(format_, direction):
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"int32"
,
"int64"
])
def
test_sample_neighbors_exclude_edges_heteroG
(
dtype
):
@
pytest
.
mark
.
parametrize
(
"fused"
,
[
False
,
True
])
def
test_sample_neighbors_exclude_edges_heteroG
(
dtype
,
fused
):
if
fused
and
(
F
.
_default_context_str
==
"gpu"
or
F
.
backend_name
!=
"pytorch"
):
pytest
.
skip
(
"Fused sampling support CPU with backend PyTorch."
)
d_i_d_u_nodes
=
F
.
zerocopy_from_numpy
(
np
.
unique
(
np
.
random
.
randint
(
300
,
size
=
100
,
dtype
=
dtype
))
)
...
...
@@ -1565,7 +1632,7 @@ def test_sample_neighbors_exclude_edges_heteroG(dtype):
(
"drug"
,
"treats"
,
"disease"
):
excluded_d_t_d_edges
,
}
sg
=
dgl
.
sampling
.
sample_neighbors
(
sg
=
sample_neighbors
_fusing_mode
[
fused
]
(
g
,
{
"drug"
:
sampled_drug_node
,
...
...
@@ -1576,6 +1643,48 @@ def test_sample_neighbors_exclude_edges_heteroG(dtype):
exclude_edges
=
excluded_edges
,
)
if
fused
:
def
contain_edge
(
g
,
sg
,
etype
,
u
,
v
):
# set of subgraph graph edges deduced from original graph
org_edges
=
set
(
map
(
tuple
,
np
.
stack
(
g
.
find_edges
(
sg
.
edges
[
etype
].
data
[
dgl
.
EID
],
etype
),
axis
=
1
,
),
)
)
# set of excluded edges
excluded_edges
=
set
(
map
(
tuple
,
np
.
stack
((
u
,
v
),
axis
=
1
)))
diff_set
=
org_edges
-
excluded_edges
return
len
(
diff_set
)
!=
len
(
org_edges
)
assert
not
contain_edge
(
g
,
sg
,
(
"drug"
,
"interacts"
,
"drug"
),
did_excluded_nodes_U
,
did_excluded_nodes_V
,
)
assert
not
contain_edge
(
g
,
sg
,
(
"drug"
,
"interacts"
,
"gene"
),
dig_excluded_nodes_U
,
dig_excluded_nodes_V
,
)
assert
not
contain_edge
(
g
,
sg
,
(
"drug"
,
"treats"
,
"disease"
),
dtd_excluded_nodes_U
,
dtd_excluded_nodes_V
,
)
else
:
assert
not
np
.
any
(
F
.
asnumpy
(
sg
.
has_edges_between
(
...
...
@@ -1606,7 +1715,12 @@ def test_sample_neighbors_exclude_edges_heteroG(dtype):
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"int32"
,
"int64"
])
def
test_sample_neighbors_exclude_edges_homoG
(
dtype
):
@
pytest
.
mark
.
parametrize
(
"fused"
,
[
False
,
True
])
def
test_sample_neighbors_exclude_edges_homoG
(
dtype
,
fused
):
if
fused
and
(
F
.
_default_context_str
==
"gpu"
or
F
.
backend_name
!=
"pytorch"
):
pytest
.
skip
(
"Fused sampling support CPU with backend PyTorch."
)
u_nodes
=
F
.
zerocopy_from_numpy
(
np
.
unique
(
np
.
random
.
randint
(
300
,
size
=
100
,
dtype
=
dtype
))
)
...
...
@@ -1629,10 +1743,30 @@ def test_sample_neighbors_exclude_edges_homoG(dtype):
excluded_nodes_U
=
g_edges
[
U
][
b_idx
:
e_idx
]
excluded_nodes_V
=
g_edges
[
V
][
b_idx
:
e_idx
]
sg
=
dgl
.
sampling
.
sample_neighbors
(
sg
=
sample_neighbors
_fusing_mode
[
fused
]
(
g
,
sampled_node
,
sampled_amount
,
exclude_edges
=
excluded_edges
)
if
fused
:
def
contain_edge
(
g
,
sg
,
u
,
v
):
# set of subgraph graph edges deduced from original graph
org_edges
=
set
(
map
(
tuple
,
np
.
stack
(
g
.
find_edges
(
sg
.
edges
[
"_E"
].
data
[
dgl
.
EID
]),
axis
=
1
),
)
)
# set of excluded edges
excluded_edges
=
set
(
map
(
tuple
,
np
.
stack
((
u
,
v
),
axis
=
1
)))
diff_set
=
org_edges
-
excluded_edges
return
len
(
diff_set
)
!=
len
(
org_edges
)
assert
not
contain_edge
(
g
,
sg
,
excluded_nodes_U
,
excluded_nodes_V
)
else
:
assert
not
np
.
any
(
F
.
asnumpy
(
sg
.
has_edges_between
(
excluded_nodes_U
,
excluded_nodes_V
))
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment