Unverified Commit 4135b1bd authored by AdamGrabowski's avatar AdamGrabowski Committed by GitHub
Browse files

[Performance] Fused sampling with compaction (#5924)


Co-authored-by: default avatarHesham Mostafa <hesham.mostafa@intel.com>
parent 4ceb0bff
import time
import dgl
import dgl.function as fn
import numpy as np
import torch
from .. import utils
@utils.benchmark("time")
@utils.parametrize_cpu("graph_name", ["livejournal", "reddit"])
@utils.parametrize_gpu("graph_name", ["ogbn-arxiv", "reddit"])
@utils.parametrize("format", ["csr", "csc"])
@utils.parametrize("seed_nodes_num", [200, 5000, 20000])
@utils.parametrize("fanout", [5, 20, 40])
def track_time(graph_name, format, seed_nodes_num, fanout):
    """Measure the average latency of ``dgl.sampling.sample_neighbors_fused``.

    Parameters
    ----------
    graph_name : str
        Name of the benchmark graph to load through ``utils.get_graph``.
    format : str
        Sparse format of the graph. ``"csc"`` samples inbound edges, any
        other value samples outbound edges.
    seed_nodes_num : int
        Number of distinct seed nodes to sample neighbors for.
    fanout : int
        Number of neighbors sampled per seed node.

    Returns
    -------
    float
        Elapsed seconds per call, averaged over the timed iterations.
    """
    # NOTE(review): fused sampling appears to reject non-CPU inputs, so the
    # GPU parametrization above presumably raises -- confirm before relying
    # on GPU numbers.
    device = utils.get_bench_device()
    graph = utils.get_graph(graph_name, format).to(device)
    edge_dir = "in" if format == "csc" else "out"

    # The fused sampler requires the seed nodes to be unique, so draw them
    # without replacement instead of with ``randint`` (which may repeat IDs).
    seed_nodes = np.random.choice(
        graph.num_nodes(), seed_nodes_num, replace=False
    )
    seed_nodes = torch.from_numpy(seed_nodes).to(device)

    # dry run
    for _ in range(3):
        dgl.sampling.sample_neighbors_fused(
            graph, seed_nodes, fanout, edge_dir=edge_dir
        )

    # timing
    num_iters = 50
    with utils.Timer() as t:
        for _ in range(num_iters):
            dgl.sampling.sample_neighbors_fused(
                graph, seed_nodes, fanout, edge_dir=edge_dir
            )

    return t.elapsed_secs / num_iters
...@@ -572,6 +572,72 @@ COOMatrix CSRRowWiseSampling( ...@@ -572,6 +572,72 @@ COOMatrix CSRRowWiseSampling(
CSRMatrix mat, IdArray rows, int64_t num_samples, CSRMatrix mat, IdArray rows, int64_t num_samples,
NDArray prob_or_mask = NDArray(), bool replace = true); NDArray prob_or_mask = NDArray(), bool replace = true);
/**
 * @brief Randomly select a fixed number of non-zero entries along each given
 * row independently.
 *
 * The function performs random choices along each row independently.
 * The picked entries are returned in the form of a CSR matrix, together with
 * an additional IdArray that is an expanded version of the CSR's index
 * pointers: for every picked entry it stores the position (within \a rows) of
 * the row the entry was picked from, i.e. the row array of a COO
 * representation.
 *
 * With the template parameter \a map_seed_nodes set to true, the rows are also
 * saved as new seed nodes (into \a new_seed_nodes) and mapped (through
 * \a seed_mapping).
 *
 * If replace is false and a row has fewer non-zero values than num_samples,
 * all the values are picked.
 *
 * Examples:
 *
 * // csr.num_rows = 4;
 * // csr.num_cols = 4;
 * // csr.indptr = [0, 2, 3, 3, 5]
 * // csr.indices = [0, 1, 1, 2, 3]
 * // csr.data = [2, 3, 0, 1, 4]
 * CSRMatrix csr = ...;
 * IdArray rows = ... ; // [1, 3]
 * IdArray seed_mapping = [-1, -1, -1, -1];
 * std::vector<IdType> new_seed_nodes = {};
 *
 * std::pair<CSRMatrix, IdArray> sampled = CSRRowWiseSamplingFused<
 *                                         IdType, true>(
 *                                         csr, rows, seed_mapping,
 *                                         &new_seed_nodes, 2,
 *                                         FloatArray(), false);
 * // possible sampled csr matrix:
 * // sampled.first.num_rows = 2
 * // sampled.first.num_cols = 3
 * // sampled.first.indptr = [0, 1, 3]
 * // sampled.first.indices = [1, 2, 3]
 * // sampled.first.data = [0, 1, 4]
 * // sampled.second = [0, 1, 1]
 * // seed_mapping = [-1, 0, -1, 1];
 * // new_seed_nodes = {1, 3};
 *
 * @tparam IdType Graph's index data type, can be int32_t or int64_t
 * @tparam map_seed_nodes If set to true, rows are mapped through seed_mapping
 *         and copied to new_seed_nodes
 * @param mat Input CSR matrix.
 * @param rows Rows to sample from.
 * @param seed_mapping Mapping array used if map_seed_nodes=true. If so, each
 *        row from rows will be set to its position, e.g.
 *        mapping[rows[i]] = i.
 * @param new_seed_nodes Vector used if map_seed_nodes=true. If so, it will
 *        contain rows.
 * @param num_samples Number of samples
 * @param prob_or_mask Unnormalized probability array or mask array.
 *                     Should be of the same length as the data array.
 *                     If an empty array is provided, assume uniform.
 * @param replace True if sample with replacement
 * @return A pair of (CSRMatrix storing the picked row, col and data indices,
 *         IdArray holding the COO version of the picked rows).
 * @note The edges of the entire graph must be ordered by their edge types,
 *       rows must be unique.
 *       NOTE(review): the edge-type ordering requirement looks inherited from
 *       the per-edge-type sampler's documentation -- confirm it applies here.
 */
template <typename IdType, bool map_seed_nodes>
std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused(
CSRMatrix mat, IdArray rows, IdArray seed_mapping,
std::vector<IdType>* new_seed_nodes, int64_t num_samples,
NDArray prob_or_mask = NDArray(), bool replace = true);
/** /**
* @brief Randomly select a fixed number of non-zero entries for each edge type * @brief Randomly select a fixed number of non-zero entries for each edge type
* along each given row independently. * along each given row independently.
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/base_heterograph.h> #include <dgl/base_heterograph.h>
#include <tuple>
#include <vector> #include <vector>
namespace dgl { namespace dgl {
...@@ -47,6 +48,55 @@ HeteroSubgraph SampleNeighbors( ...@@ -47,6 +48,55 @@ HeteroSubgraph SampleNeighbors(
const std::vector<FloatArray>& probability, const std::vector<FloatArray>& probability,
const std::vector<IdArray>& exclude_edges, bool replace = true); const std::vector<IdArray>& exclude_edges, bool replace = true);
/**
 * @brief Sample from the neighbors of the given nodes and convert a graph into
 * a bipartite-structured graph for message passing.
 *
 * Specifically, we create one node type \c ntype_l on the "left" side and
 * another node type \c ntype_r on the "right" side for each node type \c ntype.
 * The nodes of type \c ntype_r would contain the nodes designated by the
 * caller, and node type \c ntype_l would contain the nodes that have an edge
 * connecting to one of the designated nodes.
 *
 * The nodes of \c ntype_l would also contain the nodes in node type \c ntype_r.
 * When sampling with replacement, the sampled subgraph could have parallel
 * edges.
 *
 * For sampling without replacement, if fanout > the number of neighbors, all
 * the neighbors will be sampled.
 *
 * Non-deterministic algorithm; requires the nodes parameter to store unique
 * node IDs.
 *
 * @tparam IdType Graph's index data type, can be int32_t or int64_t
 * @param hg The input graph.
 * @param nodes Node IDs of each type. The vector length must be equal to the
 * number of node types. Empty array is allowed.
 * @param mapping External parameter that should be set to a vector of IdArrays
 *                filled with -1, required for mapping of nodes in returned
 *                graph
 * @param fanouts Number of sampled neighbors for each edge type. The vector
 * length should be equal to the number of edge types, or one if they all have
 * the same fanout.
 * @param dir Edge direction.
 * @param prob_or_mask A vector of 1D arrays, indicating the transition
 * probability (or mask) of each edge by edge type. An empty array assumes
 * uniform transition.
 * @param exclude_edges Edges IDs of each type which will be excluded during
 * sampling. The vector length must be equal to the number of edges types. Empty
 * array is allowed.
 * @param replace If true, sample with replacement.
 * @return A tuple whose first element is the sampled neighborhoods as a graph.
 * The two vectors of IdArrays are presumably the induced node IDs and the
 * induced edge IDs per type (the Python binding unpacks them under those
 * names) -- confirm against the implementation.
 */
template <typename IdType>
std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
SampleNeighborsFused(
const HeteroGraphPtr hg, const std::vector<IdArray>& nodes,
const std::vector<IdArray>& mapping, const std::vector<int64_t>& fanouts,
EdgeDir dir, const std::vector<NDArray>& prob_or_mask,
const std::vector<IdArray>& exclude_edges, bool replace = true);
/** /**
* Select the neighbors with k-largest weights on the connecting edges for each * Select the neighbors with k-largest weights on the connecting edges for each
* given node. * given node.
......
"""Data loading components for neighbor sampling""" """Data loading components for neighbor sampling"""
from .. import backend as F
from ..base import EID, NID from ..base import EID, NID
from ..heterograph import DGLGraph
from ..transforms import to_block from ..transforms import to_block
from .base import BlockSampler from .base import BlockSampler
...@@ -54,6 +56,9 @@ class NeighborSampler(BlockSampler): ...@@ -54,6 +56,9 @@ class NeighborSampler(BlockSampler):
output_device : device, optional output_device : device, optional
The device of the output subgraphs or MFGs. Default is the same as the The device of the output subgraphs or MFGs. Default is the same as the
minibatch of seed nodes. minibatch of seed nodes.
fused : bool, default True
If True and device is CPU fused sample neighbors is invoked. This version
requires seed_nodes to be unique
Examples Examples
-------- --------
...@@ -120,6 +125,7 @@ class NeighborSampler(BlockSampler): ...@@ -120,6 +125,7 @@ class NeighborSampler(BlockSampler):
prefetch_labels=None, prefetch_labels=None,
prefetch_edge_feats=None, prefetch_edge_feats=None,
output_device=None, output_device=None,
fused=True,
): ):
super().__init__( super().__init__(
prefetch_node_feats=prefetch_node_feats, prefetch_node_feats=prefetch_node_feats,
...@@ -137,10 +143,43 @@ class NeighborSampler(BlockSampler): ...@@ -137,10 +143,43 @@ class NeighborSampler(BlockSampler):
) )
self.prob = prob or mask self.prob = prob or mask
self.replace = replace self.replace = replace
self.fused = fused
self.mapping = {}
self.g = None
def sample_blocks(self, g, seed_nodes, exclude_eids=None): def sample_blocks(self, g, seed_nodes, exclude_eids=None):
output_nodes = seed_nodes output_nodes = seed_nodes
blocks = [] blocks = []
if self.fused:
cpu = F.device_type(g.device) == "cpu"
if isinstance(seed_nodes, dict):
for ntype in list(seed_nodes.keys()):
if not cpu:
break
cpu = (
cpu and F.device_type(seed_nodes[ntype].device) == "cpu"
)
else:
cpu = cpu and F.device_type(seed_nodes.device) == "cpu"
if cpu and isinstance(g, DGLGraph) and F.backend_name == "pytorch":
if self.g != g:
self.mapping = {}
self.g = g
for fanout in reversed(self.fanouts):
block = g.sample_neighbors_fused(
seed_nodes,
fanout,
edge_dir=self.edge_dir,
prob=self.prob,
replace=self.replace,
exclude_edges=exclude_eids,
mapping=self.mapping,
)
seed_nodes = block.srcdata[NID]
blocks.insert(0, block)
return seed_nodes, output_nodes, blocks
for fanout in reversed(self.fanouts): for fanout in reversed(self.fanouts):
frontier = g.sample_neighbors( frontier = g.sample_neighbors(
seed_nodes, seed_nodes,
......
"""Neighbor sampling APIs""" """Neighbor sampling APIs"""
import os
import torch
from .. import backend as F, ndarray as nd, utils from .. import backend as F, ndarray as nd, utils
from .._ffi.function import _init_api from .._ffi.function import _init_api
from ..base import DGLError, EID from ..base import DGLError, EID
from ..heterograph import DGLGraph from ..heterograph import DGLBlock, DGLGraph
from .utils import EidExcluder from .utils import EidExcluder
__all__ = [ __all__ = [
"sample_etype_neighbors", "sample_etype_neighbors",
"sample_neighbors", "sample_neighbors",
"sample_neighbors_fused",
"sample_neighbors_biased", "sample_neighbors_biased",
"select_topk", "select_topk",
] ]
...@@ -379,6 +384,126 @@ def sample_neighbors( ...@@ -379,6 +384,126 @@ def sample_neighbors(
return frontier if output_device is None else frontier.to(output_device) return frontier if output_device is None else frontier.to(output_device)
def sample_neighbors_fused(
    g,
    nodes,
    fanout,
    edge_dir="in",
    prob=None,
    replace=False,
    copy_ndata=True,
    copy_edata=True,
    exclude_edges=None,
    mapping=None,
):
    """Sample neighboring edges of the given nodes and return them as a block.

    For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges
    will be randomly chosen. Sampling and compaction are fused into one step:
    the result only holds the sampled edges, and nodes are renumbered starting
    from id 0, which becomes the new node id of the first seed node.

    Parameters
    ----------
    g : DGLGraph
        The graph. Must reside on CPU; fused sampling raises ``DGLError`` for
        other devices and for backends other than PyTorch.
    nodes : tensor or dict
        Node IDs to sample neighbors from. The IDs must be unique.

        This argument can take a single ID tensor or a dictionary of node types and ID tensors.
        If a single tensor is given, the graph must only have one type of nodes.
    fanout : int or dict[etype, int]
        The number of edges to be sampled for each node on each edge type.

        This argument can take a single int or a dictionary of edge types and ints.
        If a single int is given, DGL will sample this number of edges for each node for
        every edge type.

        If -1 is given for a single edge type, all the neighboring edges with that edge
        type and non-zero probability will be selected.
    edge_dir : str, optional
        Determines whether to sample inbound or outbound edges.

        Can take either ``in`` for inbound edges or ``out`` for outbound edges.
    prob : str, optional
        Feature name used as the (unnormalized) probabilities associated with each
        neighboring edge of a node. The feature must have only one element for each
        edge.

        The features must be non-negative floats or boolean. Otherwise, the result
        will be undefined.
    replace : bool, optional
        If True, sample with replacement.
    copy_ndata: bool, optional
        If True, the node features of the new graph are copied from
        the original graph. If False, the new graph will not have any
        node features.

        (Default: True)
    copy_edata: bool, optional
        If True, the edge features of the new graph are copied from
        the original graph. If False, the new graph will not have any
        edge features.

        (Default: True)
    exclude_edges: tensor or dict
        Edge IDs to exclude during sampling neighbors for the seed nodes.

        This argument can take a single ID tensor or a dictionary of edge types and ID tensors.
        If a single tensor is given, the graph must only have one type of nodes.
    mapping : dictionary, optional
        Used by fused version of NeighborSampler. To avoid constant data allocation
        provide empty dictionary ({}) that will be allocated once with proper data and reused
        by each function call

        (Default: None)

    Returns
    -------
    DGLBlock
        A sampled block containing only the sampled neighboring edges.

    Notes
    -----
    If :attr:`copy_ndata` or :attr:`copy_edata` is True, same tensors are used as
    the node or edge features of the original graph and the new graph.
    As a result, users should avoid performing in-place operations
    on the node features of the new graph to avoid feature corruption.
    """
    pinned = g.is_pinned()
    # For a pinned graph the exclusion list is not handed to the sampler and is
    # applied to the result afterwards instead.
    # NOTE(review): presumably the sampling kernel cannot consume exclusions
    # for pinned graphs -- confirm.
    frontier = _sample_neighbors(
        g,
        nodes,
        fanout,
        edge_dir=edge_dir,
        prob=prob,
        replace=replace,
        copy_ndata=copy_ndata,
        copy_edata=copy_edata,
        exclude_edges=None if pinned else exclude_edges,
        fused=True,
        mapping=mapping,
    )
    if pinned and exclude_edges is not None:
        eid_excluder = EidExcluder(exclude_edges)
        frontier = eid_excluder(frontier)
    return frontier
def _sample_neighbors( def _sample_neighbors(
g, g,
nodes, nodes,
...@@ -390,6 +515,8 @@ def _sample_neighbors( ...@@ -390,6 +515,8 @@ def _sample_neighbors(
copy_edata=True, copy_edata=True,
_dist_training=False, _dist_training=False,
exclude_edges=None, exclude_edges=None,
fused=False,
mapping=None,
): ):
if not isinstance(nodes, dict): if not isinstance(nodes, dict):
if len(g.ntypes) > 1: if len(g.ntypes) > 1:
...@@ -446,6 +573,53 @@ def _sample_neighbors( ...@@ -446,6 +573,53 @@ def _sample_neighbors(
else: else:
excluded_edges_all_t.append(nd.array([], ctx=ctx)) excluded_edges_all_t.append(nd.array([], ctx=ctx))
if fused:
if _dist_training:
raise DGLError(
"distributed training not supported in fused sampling"
)
cpu = F.device_type(g.device) == "cpu"
if isinstance(nodes, dict):
for ntype in list(nodes.keys()):
if not cpu:
break
cpu = cpu and F.device_type(nodes[ntype].device) == "cpu"
else:
cpu = cpu and F.device_type(nodes.device) == "cpu"
if not cpu or F.backend_name != "pytorch":
raise DGLError(
"Only PyTorch backend and cpu is supported in fused sampling"
)
if mapping is None:
mapping = {}
mapping_name = "__mapping" + str(os.getpid())
if mapping_name not in mapping.keys():
mapping[mapping_name] = [
torch.LongTensor(g.num_nodes(ntype)).fill_(-1)
for ntype in g.ntypes
]
subgidx, induced_nodes, induced_edges = _CAPI_DGLSampleNeighborsFused(
g._graph,
nodes_all_types,
[F.to_dgl_nd(m) for m in mapping[mapping_name]],
fanout_array,
edge_dir,
prob_arrays,
excluded_edges_all_t,
replace,
)
for mapping_vector, src_nodes in zip(
mapping[mapping_name], induced_nodes
):
mapping_vector[F.from_dgl_nd(src_nodes).type(F.int64)] = -1
new_ntypes = (g.ntypes, g.ntypes)
ret = DGLBlock(subgidx, new_ntypes, g.etypes)
assert ret.is_unibipartite
else:
subgidx = _CAPI_DGLSampleNeighbors( subgidx = _CAPI_DGLSampleNeighbors(
g._graph, g._graph,
nodes_all_types, nodes_all_types,
...@@ -455,8 +629,8 @@ def _sample_neighbors( ...@@ -455,8 +629,8 @@ def _sample_neighbors(
excluded_edges_all_t, excluded_edges_all_t,
replace, replace,
) )
induced_edges = subgidx.induced_edges
ret = DGLGraph(subgidx.graph, g.ntypes, g.etypes) ret = DGLGraph(subgidx.graph, g.ntypes, g.etypes)
induced_edges = subgidx.induced_edges
# handle features # handle features
# (TODO) (BarclayII) DGL distributed fails with bus error, freezes, or other # (TODO) (BarclayII) DGL distributed fails with bus error, freezes, or other
...@@ -465,12 +639,31 @@ def _sample_neighbors( ...@@ -465,12 +639,31 @@ def _sample_neighbors(
# only set the edge IDs. # only set the edge IDs.
if not _dist_training: if not _dist_training:
if copy_ndata: if copy_ndata:
if fused:
src_node_ids = [F.from_dgl_nd(src) for src in induced_nodes]
dst_node_ids = [
utils.toindex(
nodes.get(ntype, []), g._idtype_str
).tousertensor(ctx=F.to_backend_ctx(g._graph.ctx))
for ntype in g.ntypes
]
node_frames = utils.extract_node_subframes_for_block(
g, src_node_ids, dst_node_ids
)
utils.set_new_frames(ret, node_frames=node_frames)
else:
node_frames = utils.extract_node_subframes(g, device) node_frames = utils.extract_node_subframes(g, device)
utils.set_new_frames(ret, node_frames=node_frames) utils.set_new_frames(ret, node_frames=node_frames)
if copy_edata: if copy_edata:
if fused:
edge_ids = [F.from_dgl_nd(eid) for eid in induced_edges]
edge_frames = utils.extract_edge_subframes(g, edge_ids)
utils.set_new_frames(ret, edge_frames=edge_frames)
else:
edge_frames = utils.extract_edge_subframes(g, induced_edges) edge_frames = utils.extract_edge_subframes(g, induced_edges)
utils.set_new_frames(ret, edge_frames=edge_frames) utils.set_new_frames(ret, edge_frames=edge_frames)
else: else:
for i, etype in enumerate(ret.canonical_etypes): for i, etype in enumerate(ret.canonical_etypes):
ret.edges[etype].data[EID] = induced_edges[i] ret.edges[etype].data[EID] = induced_edges[i]
...@@ -479,6 +672,7 @@ def _sample_neighbors( ...@@ -479,6 +672,7 @@ def _sample_neighbors(
DGLGraph.sample_neighbors = utils.alias_func(sample_neighbors) DGLGraph.sample_neighbors = utils.alias_func(sample_neighbors)
DGLGraph.sample_neighbors_fused = utils.alias_func(sample_neighbors_fused)
def sample_neighbors_biased( def sample_neighbors_biased(
......
...@@ -597,6 +597,47 @@ COOMatrix CSRRowWiseSampling( ...@@ -597,6 +597,47 @@ COOMatrix CSRRowWiseSampling(
return ret; return ret;
} }
template <typename IdType, bool map_seed_nodes>
std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused(
    CSRMatrix mat, IdArray rows, IdArray seed_mapping,
    std::vector<IdType>* new_seed_nodes, int64_t num_samples,
    NDArray prob_or_mask, bool replace) {
  // Device/dtype dispatcher: forwards to the weighted kernel when a
  // probability/mask array is supplied and to the uniform kernel otherwise.
  std::pair<CSRMatrix, IdArray> sampled;
  const bool weighted = !IsNullArray(prob_or_mask);
  if (weighted) {
    // The weights are validated against `rows` before dispatching.
    CHECK_VALID_CONTEXT(prob_or_mask, rows);
    ATEN_XPU_SWITCH(rows->ctx.device_type, XPU, "CSRRowWiseSamplingFused", {
      ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(
          prob_or_mask->dtype, FloatType, "probability or mask", {
            sampled = impl::CSRRowWiseSamplingFused<
                XPU, IdType, FloatType, map_seed_nodes>(
                mat, rows, seed_mapping, new_seed_nodes, num_samples,
                prob_or_mask, replace);
          });
    });
  } else {
    ATEN_XPU_SWITCH(
        rows->ctx.device_type, XPU, "CSRRowWiseSamplingUniformFused", {
          sampled =
              impl::CSRRowWiseSamplingUniformFused<XPU, IdType, map_seed_nodes>(
                  mat, rows, seed_mapping, new_seed_nodes, num_samples,
                  replace);
        });
  }
  return sampled;
}
// Explicit instantiations of the dispatcher for 64-bit and 32-bit IDs, each
// with seed-node mapping enabled (true) and disabled (false).
template std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused<int64_t, true>(
CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused<int64_t, false>(
CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused<int32_t, true>(
CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused<int32_t, false>(
CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
COOMatrix CSRRowWisePerEtypeSampling( COOMatrix CSRRowWisePerEtypeSampling(
CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset, CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_samples, const std::vector<int64_t>& num_samples,
......
...@@ -178,6 +178,14 @@ COOMatrix CSRRowWiseSampling( ...@@ -178,6 +178,14 @@ COOMatrix CSRRowWiseSampling(
CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask, CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
bool replace); bool replace);
// Device-specific fused row-wise sampling with per-entry weights.
// DType is the type of probability data; map_seed_nodes selects whether the
// rows are also recorded through seed_mapping / new_seed_nodes.
template <
DGLDeviceType XPU, typename IdxType, typename DType, bool map_seed_nodes>
std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused(
CSRMatrix mat, IdArray rows, IdArray seed_mapping,
std::vector<IdxType>* new_seed_nodes, int64_t num_samples,
NDArray prob_or_mask, bool replace);
// FloatType is the type of probability data. // FloatType is the type of probability data.
template <DGLDeviceType XPU, typename IdType, typename DType> template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix CSRRowWisePerEtypeSampling( COOMatrix CSRRowWisePerEtypeSampling(
...@@ -190,6 +198,11 @@ template <DGLDeviceType XPU, typename IdType> ...@@ -190,6 +198,11 @@ template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRRowWiseSamplingUniform( COOMatrix CSRRowWiseSamplingUniform(
CSRMatrix mat, IdArray rows, int64_t num_samples, bool replace); CSRMatrix mat, IdArray rows, int64_t num_samples, bool replace);
// Device-specific fused row-wise sampling with uniform probability (no
// probability/mask array); map_seed_nodes selects whether the rows are also
// recorded through seed_mapping / new_seed_nodes.
template <DGLDeviceType XPU, typename IdType, bool map_seed_nodes>
std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingUniformFused(
CSRMatrix mat, IdArray rows, IdArray seed_mapping,
std::vector<IdType>* new_seed_nodes, int64_t num_samples, bool replace);
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRRowWisePerEtypeSamplingUniform( COOMatrix CSRRowWisePerEtypeSamplingUniform(
CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset, CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
......
...@@ -223,5 +223,27 @@ ConcurrentIdHashMap<IdType>::AttemptInsertAt(int64_t pos, IdType key) { ...@@ -223,5 +223,27 @@ ConcurrentIdHashMap<IdType>::AttemptInsertAt(int64_t pos, IdType key) {
template class ConcurrentIdHashMap<int32_t>; template class ConcurrentIdHashMap<int32_t>;
template class ConcurrentIdHashMap<int64_t>; template class ConcurrentIdHashMap<int64_t>;
/**
 * @brief Atomically replace the value at \a ptr with 0 if it currently equals
 * -1.
 *
 * @tparam IdType Integer ID type; instantiated for int32_t and int64_t.
 * @param ptr Address of the value to update.
 * @return true if the value was -1 and has been swapped to 0, false if it held
 *         any other value (in which case it is left untouched).
 */
template <typename IdType>
bool BoolCompareAndSwap(IdType* ptr) {
#ifdef _MSC_VER
  // The MSVC intrinsics return the value observed before the exchange, so the
  // swap happened exactly when that value is the comparand (-1).
  if (sizeof(IdType) == 4) {
    return _InterlockedCompareExchange(reinterpret_cast<LONG*>(ptr), 0, -1) ==
           -1;
  } else if (sizeof(IdType) == 8) {
    return _InterlockedCompareExchange64(
               reinterpret_cast<LONGLONG*>(ptr), 0, -1) == -1;
  } else {
    LOG(FATAL) << "ID can only be int32 or int64";
    // Not expected to be reached; keeps every control path returning a value.
    return false;
  }
#elif __GNUC__  // _MSC_VER
  return __sync_bool_compare_and_swap(ptr, -1, 0);
#else  // _MSC_VER
#error "CompareAndSwap is not supported on this platform."
#endif  // _MSC_VER
}

template bool BoolCompareAndSwap<int32_t>(int32_t*);
template bool BoolCompareAndSwap<int64_t>(int64_t*);
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
...@@ -195,6 +195,9 @@ class ConcurrentIdHashMap { ...@@ -195,6 +195,9 @@ class ConcurrentIdHashMap {
IdType mask_; IdType mask_;
}; };
/**
 * @brief Atomically set the value at \a ptr to 0 if it currently equals -1.
 * @return true if the swap took place, false otherwise.
 */
template <typename IdType>
bool BoolCompareAndSwap(IdType* ptr);
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <string> #include <string>
#include <utility>
#include <vector> #include <vector>
namespace dgl { namespace dgl {
...@@ -94,6 +95,115 @@ using EtypeRangePickFn = std::function<void( ...@@ -94,6 +95,115 @@ using EtypeRangePickFn = std::function<void(
const std::vector<IdxType>& et_idx, const std::vector<IdxType>& et_eid, const std::vector<IdxType>& et_idx, const std::vector<IdxType>& et_eid,
const IdxType* eid, IdxType* out_idx)>; const IdxType* eid, IdxType* out_idx)>;
/**
 * @brief Pick non-zero entries for each requested row and return them as a
 * CSR matrix with one row per entry of \a rows, plus a COO-style row array.
 *
 * The work is split across OpenMP threads in two phases separated by
 * barriers: first every thread counts the picks of its contiguous slice of
 * rows, then (after the master thread has allocated the outputs) every thread
 * performs the picks directly into its slice of the output arrays.
 *
 * When \a map_seed_nodes is true, seed_mapping[rows[i]] is set to i and
 * \a rows is copied into \a new_seed_nodes.
 *
 * @note The \a replace parameter is not read in this function, and the
 *       \a num_picks parameter is shadowed by a local of the same name inside
 *       the pick loop; the per-row pick count comes from \a num_picks_fn.
 *
 * @return Pair of (CSR matrix of the picked entries, IdArray holding for each
 *         picked entry the position of its row within \a rows).
 */
template <typename IdxType, bool map_seed_nodes>
std::pair<CSRMatrix, IdArray> CSRRowWisePickFused(
CSRMatrix mat, IdArray rows, IdArray seed_mapping,
std::vector<IdxType>* new_seed_nodes, int64_t num_picks, bool replace,
PickFn<IdxType> pick_fn, NumPicksFn<IdxType> num_picks_fn) {
using namespace aten;
const IdxType* indptr = static_cast<IdxType*>(mat.indptr->data);
const IdxType* indices = static_cast<IdxType*>(mat.indices->data);
// data is nullptr when the matrix carries no explicit data array.
const IdxType* data =
CSRHasData(mat) ? static_cast<IdxType*>(mat.data->data) : nullptr;
const IdxType* rows_data = static_cast<IdxType*>(rows->data);
const int64_t num_rows = rows->shape[0];
const auto& ctx = mat.indptr->ctx;
const auto& idtype = mat.indptr->dtype;
// Only dereferenced when map_seed_nodes is true.
IdxType* seed_mapping_data = nullptr;
if (map_seed_nodes) seed_mapping_data = seed_mapping.Ptr<IdxType>();
const int num_threads = runtime::compute_num_threads(0, num_rows, 1);
// global_prefix[t + 1] first receives thread t's total number of picks and is
// later turned into an inclusive prefix sum by the master thread.
std::vector<int64_t> global_prefix(num_threads + 1, 0);
// Output arrays; allocated inside the parallel region once sizes are known.
IdArray picked_col, picked_idx, picked_coo_rows;
IdArray block_csr_indptr = IdArray::Empty({num_rows + 1}, idtype, ctx);
IdxType* block_csr_indptr_data = block_csr_indptr.Ptr<IdxType>();
#pragma omp parallel num_threads(num_threads)
{
const int thread_id = omp_get_thread_num();
// Contiguous slice [start_i, end_i) of rows handled by this thread; the
// first (num_rows % num_threads) threads take one extra row.
const int64_t start_i =
thread_id * (num_rows / num_threads) +
std::min(static_cast<int64_t>(thread_id), num_rows % num_threads);
const int64_t end_i =
(thread_id + 1) * (num_rows / num_threads) +
std::min(static_cast<int64_t>(thread_id + 1), num_rows % num_threads);
assert(thread_id + 1 < num_threads || end_i == num_rows);
const int64_t num_local = end_i - start_i;
// Phase 1: per-thread prefix sum of the number of picks of each local row.
std::unique_ptr<int64_t[]> local_prefix(new int64_t[num_local + 1]);
local_prefix[0] = 0;
for (int64_t i = start_i; i < end_i; ++i) {
// build prefix-sum
const int64_t local_i = i - start_i;
const IdxType rid = rows_data[i];
// Record the position of this seed row within rows.
if (map_seed_nodes) seed_mapping_data[rid] = i;
IdxType len = num_picks_fn(
rid, indptr[rid], indptr[rid + 1] - indptr[rid], indices, data);
local_prefix[local_i + 1] = local_prefix[local_i] + len;
}
global_prefix[thread_id + 1] = local_prefix[num_local];
// Wait until every thread has published its total.
#pragma omp barrier
// Only the master thread accumulates the totals and allocates the outputs.
#pragma omp master
{
for (int t = 0; t < num_threads; ++t) {
global_prefix[t + 1] += global_prefix[t];
}
picked_col = IdArray::Empty({global_prefix[num_threads]}, idtype, ctx);
picked_idx = IdArray::Empty({global_prefix[num_threads]}, idtype, ctx);
picked_coo_rows =
IdArray::Empty({global_prefix[num_threads]}, idtype, ctx);
}
// The outputs must be allocated before any thread fetches their pointers.
#pragma omp barrier
IdxType* picked_cdata = picked_col.Ptr<IdxType>();
IdxType* picked_idata = picked_idx.Ptr<IdxType>();
IdxType* picked_rows = picked_coo_rows.Ptr<IdxType>();
// Offset of this thread's first output slot.
const IdxType thread_offset = global_prefix[thread_id];
// Phase 2: perform the picks and write them into this thread's output slice.
for (int64_t i = start_i; i < end_i; ++i) {
const IdxType rid = rows_data[i];
const int64_t local_i = i - start_i;
block_csr_indptr_data[i] = local_prefix[local_i] + thread_offset;
const IdxType off = indptr[rid];
const IdxType len = indptr[rid + 1] - off;
// Rows without entries contribute nothing; their indptr is already set.
if (len == 0) continue;
const int64_t row_offset = local_prefix[local_i] + thread_offset;
// Number of picks for this row, recovered from the prefix sum (shadows the
// function parameter of the same name).
const int64_t num_picks =
local_prefix[local_i + 1] + thread_offset - row_offset;
pick_fn(
rid, off, len, num_picks, indices, data, picked_idata + row_offset);
// pick_fn left positions into `indices` in picked_idata; translate them to
// column ids and data ids, and record the owning row position.
for (int64_t j = 0; j < num_picks; ++j) {
const IdxType picked = picked_idata[row_offset + j];
picked_cdata[row_offset + j] = indices[picked];
picked_idata[row_offset + j] = data ? data[picked] : picked;
picked_rows[row_offset + j] = i;
}
}
}
// Close the index pointer array with the total number of picks.
block_csr_indptr_data[num_rows] = global_prefix.back();
// The number of columns of the result is the number of picked entries.
const IdxType num_cols = picked_col->shape[0];
if (map_seed_nodes) {
(*new_seed_nodes).resize(num_rows);
memcpy((*new_seed_nodes).data(), rows_data, sizeof(IdxType) * num_rows);
}
return std::make_pair(
CSRMatrix(num_rows, num_cols, block_csr_indptr, picked_col, picked_idx),
picked_coo_rows);
}
// Template for picking non-zero values row-wise. The implementation utilizes // Template for picking non-zero values row-wise. The implementation utilizes
// OpenMP parallelization on rows because each row performs computation // OpenMP parallelization on rows because each row performs computation
// independently. // independently.
......
...@@ -225,6 +225,74 @@ template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, uint8_t>( ...@@ -225,6 +225,74 @@ template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, uint8_t>(
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, uint8_t>( template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, uint8_t>(
CSRMatrix, IdArray, int64_t, NDArray, bool); CSRMatrix, IdArray, int64_t, NDArray, bool);
template <
    DGLDeviceType XPU, typename IdxType, typename DType, bool map_seed_nodes>
std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused(
    CSRMatrix mat, IdArray rows, IdArray seed_mapping,
    std::vector<IdxType>* new_seed_nodes, int64_t num_samples,
    NDArray prob_or_mask, bool replace) {
  // Weighted fused sampling: the probability/mask array is mandatory here.
  CHECK(prob_or_mask.defined());
  // If num_samples is -1, select all neighbors without replacement.
  const bool with_replacement = replace && num_samples != -1;
  auto count_picks = GetSamplingNumPicksFn<IdxType, DType>(
      num_samples, prob_or_mask, with_replacement);
  auto do_pick = GetSamplingPickFn<IdxType, DType>(
      num_samples, prob_or_mask, with_replacement);
  return CSRRowWisePickFused<IdxType, map_seed_nodes>(
      mat, rows, seed_mapping, new_seed_nodes, num_samples, with_replacement,
      do_pick, count_picks);
}
// Explicit CPU instantiations of the weighted fused sampler for every
// combination of ID type (int32_t/int64_t) and probability/mask type
// (float/double/int8_t/uint8_t), first with map_seed_nodes = true ...
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int32_t, float, true>(
CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int64_t, float, true>(
CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int32_t, double, true>(
CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int64_t, double, true>(
CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int32_t, int8_t, true>(
CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int64_t, int8_t, true>(
CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int32_t, uint8_t, true>(
CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int64_t, uint8_t, true>(
CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
// ... and then with map_seed_nodes = false.
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int32_t, float, false>(
CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int64_t, float, false>(
CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int32_t, double, false>(
CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int64_t, double, false>(
CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int32_t, int8_t, false>(
CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int64_t, int8_t, false>(
CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int32_t, uint8_t, false>(
CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int64_t, uint8_t, false>(
CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
template <DGLDeviceType XPU, typename IdxType, typename DType> template <DGLDeviceType XPU, typename IdxType, typename DType>
COOMatrix CSRRowWisePerEtypeSampling( COOMatrix CSRRowWisePerEtypeSampling(
CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset, CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
...@@ -283,6 +351,33 @@ template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int32_t>( ...@@ -283,6 +351,33 @@ template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int32_t>(
template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int64_t>( template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int64_t>(
CSRMatrix, IdArray, int64_t, bool); CSRMatrix, IdArray, int64_t, bool);
/**
 * @brief Uniform row-wise sampling that delegates to the fused pick routine.
 *
 * Builds the uniform "number of picks" and "pick" callbacks for the requested
 * fanout and forwards everything to CSRRowWisePickFused.
 *
 * @param mat Input CSR matrix.
 * @param rows Rows to sample from.
 * @param seed_mapping Forwarded unchanged to CSRRowWisePickFused.
 * @param new_seed_nodes Forwarded unchanged to CSRRowWisePickFused.
 * @param num_samples Fanout per row; -1 selects every neighbor.
 * @param replace Whether to sample with replacement (ignored when
 *        num_samples is -1).
 * @return Whatever CSRRowWisePickFused returns: a CSR matrix paired with an
 *         IdArray.
 */
template <DGLDeviceType XPU, typename IdxType, bool map_seed_nodes>
std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingUniformFused(
    CSRMatrix mat, IdArray rows, IdArray seed_mapping,
    std::vector<IdxType>* new_seed_nodes, int64_t num_samples, bool replace) {
  // Selecting all neighbors (fanout of -1) is always done without replacement.
  if (num_samples == -1) {
    replace = false;
  }
  auto count_picks =
      GetSamplingUniformNumPicksFn<IdxType>(num_samples, replace);
  auto do_pick = GetSamplingUniformPickFn<IdxType>(num_samples, replace);
  return CSRRowWisePickFused<IdxType, map_seed_nodes>(
      mat, rows, seed_mapping, new_seed_nodes, num_samples, replace, do_pick,
      count_picks);
}
// Explicit instantiations of CSRRowWiseSamplingUniformFused for kDGLCPU: both
// index widths (int32_t / int64_t) combined with both values of the
// map_seed_nodes template flag.
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingUniformFused<kDGLCPU, int32_t, true>(
CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingUniformFused<kDGLCPU, int64_t, true>(
CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingUniformFused<kDGLCPU, int32_t, false>(
CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingUniformFused<kDGLCPU, int64_t, false>(
CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, bool);
template <DGLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
COOMatrix CSRRowWisePerEtypeSamplingUniform( COOMatrix CSRRowWisePerEtypeSamplingUniform(
CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset, CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
......
...@@ -6,13 +6,16 @@ ...@@ -6,13 +6,16 @@
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/aten/macro.h> #include <dgl/aten/macro.h>
#include <dgl/immutable_graph.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h> #include <dgl/runtime/container.h>
#include <dgl/runtime/parallel_for.h>
#include <dgl/sampling/neighbor.h> #include <dgl/sampling/neighbor.h>
#include <tuple> #include <tuple>
#include <utility> #include <utility>
#include "../../../array/cpu/concurrent_id_hash_map.h"
#include "../../../c_api_common.h" #include "../../../c_api_common.h"
#include "../../unit_graph.h" #include "../../unit_graph.h"
...@@ -22,6 +25,76 @@ using namespace dgl::aten; ...@@ -22,6 +25,76 @@ using namespace dgl::aten;
namespace dgl { namespace dgl {
namespace sampling { namespace sampling {
template <typename IdType>
void ExcludeCertainEdgesFused(
std::vector<CSRMatrix>* sampled_graphs, std::vector<IdArray>* induced_edges,
std::vector<IdArray>* sampled_coo_rows,
const std::vector<IdArray>& exclude_edges,
std::vector<FloatArray>* weights = nullptr) {
int etypes = (*sampled_graphs).size();
std::vector<IdArray> remain_induced_edges(etypes);
std::vector<IdArray> remain_indptrs(etypes);
std::vector<IdArray> remain_indices(etypes);
std::vector<IdArray> remain_coo_rows(etypes);
std::vector<FloatArray> remain_weights(etypes);
for (int etype = 0; etype < etypes; ++etype) {
if (exclude_edges[etype].GetSize() == 0 ||
(*sampled_graphs)[etype].num_rows == 0) {
remain_induced_edges[etype] = (*induced_edges)[etype];
if (weights) remain_weights[etype] = (*weights)[etype];
continue;
}
const auto dtype = weights && (*weights)[etype]->shape[0]
? (*weights)[etype]->dtype
: DGLDataType{kDGLFloat, 8 * sizeof(float), 1};
ATEN_FLOAT_TYPE_SWITCH(dtype, FloatType, "weights", {
IdType* indptr = (*sampled_graphs)[etype].indptr.Ptr<IdType>();
IdType* indices = (*sampled_graphs)[etype].indices.Ptr<IdType>();
IdType* coo_rows = (*sampled_coo_rows)[etype].Ptr<IdType>();
IdType* induced_edges_data = (*induced_edges)[etype].Ptr<IdType>();
FloatType* weights_data = weights && (*weights)[etype]->shape[0]
? (*weights)[etype].Ptr<FloatType>()
: nullptr;
const IdType exclude_edges_len = exclude_edges[etype]->shape[0];
std::sort(
exclude_edges[etype].Ptr<IdType>(),
exclude_edges[etype].Ptr<IdType>() + exclude_edges_len);
const IdType* exclude_edges_data = exclude_edges[etype].Ptr<IdType>();
IdType outIndices = 0;
for (IdType row = 0; row < (*sampled_graphs)[etype].indptr->shape[0] - 1;
++row) {
auto tmp_row = indptr[row];
if (outIndices != indptr[row]) indptr[row] = outIndices;
for (IdType col = tmp_row; col < indptr[row + 1]; ++col) {
if (!std::binary_search(
exclude_edges_data, exclude_edges_data + exclude_edges_len,
induced_edges_data[col])) {
indices[outIndices] = indices[col];
induced_edges_data[outIndices] = induced_edges_data[col];
coo_rows[outIndices] = coo_rows[col];
if (weights_data) weights_data[outIndices] = weights_data[col];
++outIndices;
}
}
}
indptr[(*sampled_graphs)[etype].indptr->shape[0] - 1] = outIndices;
remain_induced_edges[etype] =
aten::IndexSelect((*induced_edges)[etype], 0, outIndices);
remain_weights[etype] =
weights_data ? aten::IndexSelect((*weights)[etype], 0, outIndices)
: NullArray();
remain_indices[etype] =
aten::IndexSelect((*sampled_graphs)[etype].indices, 0, outIndices);
(*sampled_coo_rows)[etype] =
aten::IndexSelect((*sampled_coo_rows)[etype], 0, outIndices);
(*sampled_graphs)[etype] = CSRMatrix(
(*sampled_graphs)[etype].num_rows, outIndices,
(*sampled_graphs)[etype].indptr, remain_indices[etype],
remain_induced_edges[etype]);
});
}
}
std::pair<HeteroSubgraph, std::vector<FloatArray>> ExcludeCertainEdges( std::pair<HeteroSubgraph, std::vector<FloatArray>> ExcludeCertainEdges(
const HeteroSubgraph& sg, const std::vector<IdArray>& exclude_edges, const HeteroSubgraph& sg, const std::vector<IdArray>& exclude_edges,
const std::vector<FloatArray>* weights = nullptr) { const std::vector<FloatArray>* weights = nullptr) {
...@@ -266,6 +339,242 @@ HeteroSubgraph SampleNeighbors( ...@@ -266,6 +339,242 @@ HeteroSubgraph SampleNeighbors(
return ret; return ret;
} }
template <typename IdType>
std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
SampleNeighborsFused(
const HeteroGraphPtr hg, const std::vector<IdArray>& nodes,
const std::vector<IdArray>& mapping, const std::vector<int64_t>& fanouts,
EdgeDir dir, const std::vector<NDArray>& prob_or_mask,
const std::vector<IdArray>& exclude_edges, bool replace) {
CHECK_EQ(nodes.size(), hg->NumVertexTypes())
<< "Number of node ID tensors must match the number of node types.";
CHECK_EQ(fanouts.size(), hg->NumEdgeTypes())
<< "Number of fanout values must match the number of edge types.";
CHECK_EQ(prob_or_mask.size(), hg->NumEdgeTypes())
<< "Number of probability tensors must match the number of edge types.";
DGLContext ctx = aten::GetContextOf(nodes);
std::vector<CSRMatrix> sampled_graphs;
std::vector<IdArray> sampled_coo_rows;
std::vector<IdArray> induced_edges;
std::vector<IdArray> induced_vertices;
std::vector<int64_t> num_nodes_per_type;
std::vector<std::vector<IdType>> new_nodes_vec(hg->NumVertexTypes());
std::vector<int> seed_nodes_mapped(hg->NumVertexTypes(), 0);
for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) {
auto pair = hg->meta_graph()->FindEdge(etype);
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
const dgl_type_t rhs_node_type =
(dir == EdgeDir::kOut) ? src_vtype : dst_vtype;
const IdArray nodes_ntype = nodes[rhs_node_type];
const int64_t num_nodes = nodes_ntype->shape[0];
if (num_nodes == 0 || fanouts[etype] == 0) {
// Nothing to sample for this etype, create a placeholder
sampled_graphs.push_back(CSRMatrix());
sampled_coo_rows.push_back(IdArray());
induced_edges.push_back(aten::NullArray(hg->DataType(), ctx));
} else {
bool map_seed_nodes = !seed_nodes_mapped[rhs_node_type];
// sample from one relation graph
std::pair<CSRMatrix, IdArray> sampled_graph;
auto sampling_fn = map_seed_nodes
? aten::CSRRowWiseSamplingFused<IdType, true>
: aten::CSRRowWiseSamplingFused<IdType, false>;
auto req_fmt = (dir == EdgeDir::kOut) ? CSR_CODE : CSC_CODE;
auto avail_fmt = hg->SelectFormat(etype, req_fmt);
switch (avail_fmt) {
case SparseFormat::kCSR:
CHECK(dir == EdgeDir::kOut)
<< "Cannot sample out edges on CSC matrix.";
// In heterographs nodes of two diffrent types can be connected
// therefore two diffrent mappings and node vectors are needed
sampled_graph = sampling_fn(
hg->GetCSRMatrix(etype), nodes_ntype, mapping[src_vtype],
&new_nodes_vec[src_vtype], fanouts[etype], prob_or_mask[etype],
replace);
break;
case SparseFormat::kCSC:
CHECK(dir == EdgeDir::kIn) << "Cannot sample in edges on CSR matrix.";
sampled_graph = sampling_fn(
hg->GetCSCMatrix(etype), nodes_ntype, mapping[dst_vtype],
&new_nodes_vec[dst_vtype], fanouts[etype], prob_or_mask[etype],
replace);
break;
default:
LOG(FATAL) << "Unsupported sparse format.";
}
seed_nodes_mapped[rhs_node_type]++;
sampled_graphs.push_back(sampled_graph.first);
if (sampled_graph.first.data.defined())
induced_edges.push_back(sampled_graph.first.data);
else
induced_edges.push_back(
aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx));
sampled_coo_rows.push_back(sampled_graph.second);
}
}
if (!exclude_edges.empty()) {
ExcludeCertainEdgesFused<IdType>(
&sampled_graphs, &induced_edges, &sampled_coo_rows, exclude_edges);
for (size_t i = 0; i < hg->NumEdgeTypes(); i++) {
if (sampled_graphs[i].data.defined())
induced_edges[i] = std::move(sampled_graphs[i].data);
else
induced_edges[i] =
aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx);
}
}
// map indices
for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) {
auto pair = hg->meta_graph()->FindEdge(etype);
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
const dgl_type_t lhs_node_type =
(dir == EdgeDir::kIn) ? src_vtype : dst_vtype;
if (sampled_graphs[etype].num_cols != 0) {
auto num_cols = sampled_graphs[etype].num_cols;
int num_threads_col = runtime::compute_num_threads(0, num_cols, 1);
std::vector<IdType> global_prefix_col(num_threads_col + 1, 0);
std::vector<std::vector<IdType>> src_nodes_local(num_threads_col);
IdType* mapping_data_dst = mapping[lhs_node_type].Ptr<IdType>();
IdType* cdata = sampled_graphs[etype].indices.Ptr<IdType>();
#pragma omp parallel num_threads(num_threads_col)
{
const int thread_id = omp_get_thread_num();
num_threads_col = omp_get_num_threads();
const int64_t start_i =
thread_id * (num_cols / num_threads_col) +
std::min(
static_cast<int64_t>(thread_id), num_cols % num_threads_col);
const int64_t end_i = (thread_id + 1) * (num_cols / num_threads_col) +
std::min(
static_cast<int64_t>(thread_id + 1),
num_cols % num_threads_col);
assert(thread_id + 1 < num_threads_col || end_i == num_cols);
for (int64_t i = start_i; i < end_i; ++i) {
int64_t picked_idx = cdata[i];
bool spot_claimed =
BoolCompareAndSwap<IdType>(&mapping_data_dst[picked_idx]);
if (spot_claimed) src_nodes_local[thread_id].push_back(picked_idx);
}
global_prefix_col[thread_id + 1] = src_nodes_local[thread_id].size();
#pragma omp barrier
#pragma omp master
{
global_prefix_col[0] = new_nodes_vec[lhs_node_type].size();
for (int t = 0; t < num_threads_col; ++t) {
global_prefix_col[t + 1] += global_prefix_col[t];
}
}
#pragma omp barrier
int64_t mapping_shift = global_prefix_col[thread_id];
for (size_t i = 0; i < src_nodes_local[thread_id].size(); ++i)
mapping_data_dst[src_nodes_local[thread_id][i]] = mapping_shift + i;
#pragma omp barrier
for (int64_t i = start_i; i < end_i; ++i) {
IdType picked_idx = cdata[i];
IdType mapped_idx = mapping_data_dst[picked_idx];
cdata[i] = mapped_idx;
}
}
IdType offset = new_nodes_vec[lhs_node_type].size();
new_nodes_vec[lhs_node_type].resize(global_prefix_col.back());
for (int thread_id = 0; thread_id < num_threads_col; ++thread_id) {
memcpy(
new_nodes_vec[lhs_node_type].data() + offset,
&src_nodes_local[thread_id][0],
src_nodes_local[thread_id].size() * sizeof(IdType));
offset += src_nodes_local[thread_id].size();
}
}
}
// counting how many nodes of each ntype were sampled
num_nodes_per_type.resize(2 * hg->NumVertexTypes());
for (size_t i = 0; i < hg->NumVertexTypes(); i++) {
num_nodes_per_type[i] = new_nodes_vec[i].size();
num_nodes_per_type[hg->NumVertexTypes() + i] = nodes[i]->shape[0];
induced_vertices.push_back(
VecToIdArray(new_nodes_vec[i], sizeof(IdType) * 8));
}
std::vector<HeteroGraphPtr> subrels(hg->NumEdgeTypes());
for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) {
auto pair = hg->meta_graph()->FindEdge(etype);
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
if (sampled_graphs[etype].num_rows == 0) {
subrels[etype] = UnitGraph::Empty(
2, new_nodes_vec[src_vtype].size(), nodes[dst_vtype]->shape[0],
hg->DataType(), ctx);
} else {
CSRMatrix graph = sampled_graphs[etype];
if (dir == EdgeDir::kOut) {
subrels[etype] = UnitGraph::CreateFromCSRAndCOO(
2,
CSRMatrix(
nodes[src_vtype]->shape[0], new_nodes_vec[dst_vtype].size(),
graph.indptr, graph.indices,
Range(
0, graph.indices->shape[0], graph.indices->dtype.bits,
ctx)),
COOMatrix(
nodes[src_vtype]->shape[0], new_nodes_vec[dst_vtype].size(),
sampled_coo_rows[etype], graph.indices),
ALL_CODE);
} else {
subrels[etype] = UnitGraph::CreateFromCSCAndCOO(
2,
CSRMatrix(
nodes[dst_vtype]->shape[0], new_nodes_vec[src_vtype].size(),
graph.indptr, graph.indices,
Range(
0, graph.indices->shape[0], graph.indices->dtype.bits,
ctx)),
COOMatrix(
new_nodes_vec[src_vtype].size(), nodes[dst_vtype]->shape[0],
graph.indices, sampled_coo_rows[etype]),
ALL_CODE);
}
}
}
HeteroSubgraph ret;
const auto meta_graph = hg->meta_graph();
const EdgeArray etypes = meta_graph->Edges("eid");
const IdArray new_dst = Add(etypes.dst, hg->NumVertexTypes());
const auto new_meta_graph = ImmutableGraph::CreateFromCOO(
hg->NumVertexTypes() * 2, etypes.src, new_dst);
HeteroGraphPtr new_graph =
CreateHeteroGraph(new_meta_graph, subrels, num_nodes_per_type);
return std::make_tuple(new_graph, induced_edges, induced_vertices);
}
// Explicit instantiations of SampleNeighborsFused for both supported index
// widths (int64_t and int32_t).
template std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
SampleNeighborsFused<int64_t>(
const HeteroGraphPtr, const std::vector<IdArray>&,
const std::vector<IdArray>&, const std::vector<int64_t>&, EdgeDir,
const std::vector<NDArray>&, const std::vector<IdArray>&, bool);
template std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
SampleNeighborsFused<int32_t>(
const HeteroGraphPtr, const std::vector<IdArray>&,
const std::vector<IdArray>&, const std::vector<int64_t>&, EdgeDir,
const std::vector<NDArray>&, const std::vector<IdArray>&, bool);
HeteroSubgraph SampleNeighborsEType( HeteroSubgraph SampleNeighborsEType(
const HeteroGraphPtr hg, const IdArray nodes, const HeteroGraphPtr hg, const IdArray nodes,
const std::vector<int64_t>& eid2etype_offset, const std::vector<int64_t>& eid2etype_offset,
...@@ -568,6 +877,47 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighbors") ...@@ -568,6 +877,47 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighbors")
*rv = HeteroSubgraphRef(subg); *rv = HeteroSubgraphRef(subg);
}); });
// C API entry point for fused neighbor sampling. Arguments, in order: graph,
// seed nodes per node type, mapping arrays per node type, fanouts, edge
// direction ("in" or "out"), probability/mask arrays, edges to exclude, and
// the replace flag. Returns a list of [sampled graph, induced vertices,
// induced edges].
DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsFused")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      // Unpack the positional arguments.
      HeteroGraphRef hg = args[0];
      const auto seed_nodes = ListValueToVector<IdArray>(args[1]);
      auto mapping = ListValueToVector<IdArray>(args[2]);
      IdArray fanouts_array = args[3];
      const auto fanouts = fanouts_array.ToVector<int64_t>();
      const std::string dir_str = args[4];
      const auto prob_or_mask = ListValueToVector<NDArray>(args[5]);
      const auto exclude_edges = ListValueToVector<IdArray>(args[6]);
      const bool replace = args[7];

      CHECK(dir_str == "in" || dir_str == "out")
          << "Invalid edge direction. Must be \"in\" or \"out\".";
      EdgeDir dir = EdgeDir::kOut;
      if (dir_str == "in") {
        dir = EdgeDir::kIn;
      }

      // Dispatch on the graph's ID type and run the sampler.
      HeteroGraphPtr sampled_graph;
      std::vector<IdArray> edge_ids;
      std::vector<IdArray> vertex_ids;
      ATEN_ID_TYPE_SWITCH(hg->DataType(), IdType, {
        std::tie(sampled_graph, edge_ids, vertex_ids) =
            SampleNeighborsFused<IdType>(
                hg.sptr(), seed_nodes, mapping, fanouts, dir, prob_or_mask,
                exclude_edges, replace);
      });

      // Box the arrays so they can travel through the object list.
      List<Value> vertex_list;
      for (IdArray& ids : vertex_ids) {
        vertex_list.push_back(Value(MakeValue(ids)));
      }
      List<Value> edge_list;
      for (IdArray& ids : edge_ids) {
        edge_list.push_back(Value(MakeValue(ids)));
      }

      List<ObjectRef> result;
      result.push_back(HeteroGraphRef(sampled_graph));
      result.push_back(vertex_list);
      result.push_back(edge_list);
      *rv = result;
    });
DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsTopk") DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsTopk")
.set_body([](DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
......
...@@ -1218,6 +1218,21 @@ HeteroGraphPtr UnitGraph::CreateFromCSR( ...@@ -1218,6 +1218,21 @@ HeteroGraphPtr UnitGraph::CreateFromCSR(
return HeteroGraphPtr(new UnitGraph(mg, nullptr, csr, nullptr, formats)); return HeteroGraphPtr(new UnitGraph(mg, nullptr, csr, nullptr, formats));
} }
// Builds a unit graph that carries both an out-CSR and a COO representation.
// The two matrices must agree on their dimensions, and a single-node-type
// graph must be square. The in-CSR slot is left empty.
HeteroGraphPtr UnitGraph::CreateFromCSRAndCOO(
    int64_t num_vtypes, const aten::CSRMatrix& csr, const aten::COOMatrix& coo,
    dgl_format_code_t formats) {
  // Validate the inputs before allocating anything.
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  CHECK_EQ(coo.num_rows, csr.num_rows);
  CHECK_EQ(coo.num_cols, csr.num_cols);
  if (num_vtypes == 1) {
    CHECK_EQ(csr.num_rows, csr.num_cols);
  }

  auto meta_graph = CreateUnitGraphMetaGraph(num_vtypes);
  CSRPtr out_csr(new CSR(meta_graph, csr));
  COOPtr coo_part(new COO(meta_graph, coo));
  UnitGraph* unit_graph =
      new UnitGraph(meta_graph, nullptr, out_csr, coo_part, formats);
  return HeteroGraphPtr(unit_graph);
}
HeteroGraphPtr UnitGraph::CreateFromCSC( HeteroGraphPtr UnitGraph::CreateFromCSC(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr, int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
IdArray indices, IdArray edge_ids, dgl_format_code_t formats) { IdArray indices, IdArray edge_ids, dgl_format_code_t formats) {
...@@ -1237,6 +1252,21 @@ HeteroGraphPtr UnitGraph::CreateFromCSC( ...@@ -1237,6 +1252,21 @@ HeteroGraphPtr UnitGraph::CreateFromCSC(
return HeteroGraphPtr(new UnitGraph(mg, csc, nullptr, nullptr, formats)); return HeteroGraphPtr(new UnitGraph(mg, csc, nullptr, nullptr, formats));
} }
// Builds a unit graph that carries both an in-CSR (CSC) and a COO
// representation. The CSC matrix is transposed relative to the COO, so its
// rows must match the COO columns and vice versa; a single-node-type graph
// must be square. The out-CSR slot is left empty.
HeteroGraphPtr UnitGraph::CreateFromCSCAndCOO(
    int64_t num_vtypes, const aten::CSRMatrix& csc, const aten::COOMatrix& coo,
    dgl_format_code_t formats) {
  // Validate the inputs before allocating anything.
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  CHECK_EQ(coo.num_rows, csc.num_cols);
  CHECK_EQ(coo.num_cols, csc.num_rows);
  if (num_vtypes == 1) {
    CHECK_EQ(csc.num_rows, csc.num_cols);
  }

  auto meta_graph = CreateUnitGraphMetaGraph(num_vtypes);
  CSRPtr in_csr(new CSR(meta_graph, csc));
  COOPtr coo_part(new COO(meta_graph, coo));
  UnitGraph* unit_graph =
      new UnitGraph(meta_graph, in_csr, nullptr, coo_part, formats);
  return HeteroGraphPtr(unit_graph);
}
HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) { HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
if (g->NumBits() == bits) { if (g->NumBits() == bits) {
return g; return g;
......
...@@ -190,6 +190,12 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -190,6 +190,12 @@ class UnitGraph : public BaseHeteroGraph {
int64_t num_vtypes, const aten::CSRMatrix& mat, int64_t num_vtypes, const aten::CSRMatrix& mat,
dgl_format_code_t formats = ALL_CODE); dgl_format_code_t formats = ALL_CODE);
/**
 * @brief Create a graph from (out) CSR and COO arrays, both representing the
 * same graph.
 *
 * The implementation checks that num_vtypes is 1 or 2, that csr and coo have
 * the same number of rows and columns, and that csr is square when
 * num_vtypes is 1.
 *
 * @param num_vtypes Number of vertex types (1 or 2).
 * @param csr The out-CSR matrix.
 * @param coo The COO matrix with the same dimensions as csr.
 * @param formats Format code passed on to the created graph.
 * @return The created graph.
 */
static HeteroGraphPtr CreateFromCSRAndCOO(
int64_t num_vtypes, const aten::CSRMatrix& csr,
const aten::COOMatrix& coo, dgl_format_code_t formats = ALL_CODE);
/** @brief Create a graph from (in) CSC arrays */ /** @brief Create a graph from (in) CSC arrays */
static HeteroGraphPtr CreateFromCSC( static HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr, int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
...@@ -199,6 +205,12 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -199,6 +205,12 @@ class UnitGraph : public BaseHeteroGraph {
int64_t num_vtypes, const aten::CSRMatrix& mat, int64_t num_vtypes, const aten::CSRMatrix& mat,
dgl_format_code_t formats = ALL_CODE); dgl_format_code_t formats = ALL_CODE);
/**
 * @brief Create a graph from (in) CSC and COO arrays, both representing the
 * same graph.
 *
 * The implementation checks that num_vtypes is 1 or 2, that the number of
 * rows of coo equals the number of columns of csc (and vice versa), and that
 * csc is square when num_vtypes is 1.
 *
 * @param num_vtypes Number of vertex types (1 or 2).
 * @param csc The in-CSR (CSC) matrix.
 * @param coo The COO matrix, transposed in shape relative to csc.
 * @param formats Format code passed on to the created graph.
 * @return The created graph.
 */
static HeteroGraphPtr CreateFromCSCAndCOO(
int64_t num_vtypes, const aten::CSRMatrix& csc,
const aten::COOMatrix& coo, dgl_format_code_t formats = ALL_CODE);
/** @brief Convert the graph to use the given number of bits for storage */ /** @brief Convert the graph to use the given number of bits for storage */
static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits); static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
......
...@@ -7,6 +7,11 @@ import dgl ...@@ -7,6 +7,11 @@ import dgl
import numpy as np import numpy as np
import pytest import pytest
# Dispatch table keyed by the tests' ``fused`` flag: ``True`` selects the fused
# sampler (sampling with compaction), ``False`` the regular neighbor sampler.
sample_neighbors_fusing_mode = {
True: dgl.sampling.sample_neighbors_fused,
False: dgl.sampling.sample_neighbors,
}
def check_random_walk(g, metapath, traces, ntypes, prob=None, trace_eids=None): def check_random_walk(g, metapath, traces, ntypes, prob=None, trace_eids=None):
traces = F.asnumpy(traces) traces = F.asnumpy(traces)
...@@ -555,15 +560,18 @@ def _gen_neighbor_topk_test_graph(hypersparse, reverse): ...@@ -555,15 +560,18 @@ def _gen_neighbor_topk_test_graph(hypersparse, reverse):
return g, hg return g, hg
def _test_sample_neighbors(hypersparse, prob): def _test_sample_neighbors(hypersparse, prob, fused):
g, hg = _gen_neighbor_sampling_test_graph(hypersparse, False) g, hg = _gen_neighbor_sampling_test_graph(hypersparse, False)
def _test1(p, replace): def _test1(p, replace):
subg = dgl.sampling.sample_neighbors( subg = sample_neighbors_fusing_mode[fused](
g, [0, 1], -1, prob=p, replace=replace g, [0, 1], -1, prob=p, replace=replace
) )
if not fused:
assert subg.num_nodes() == g.num_nodes() assert subg.num_nodes() == g.num_nodes()
u, v = subg.edges() u, v = subg.edges()
if fused:
u, v = subg.srcdata[dgl.NID][u], subg.dstdata[dgl.NID][v]
u_ans, v_ans, e_ans = g.in_edges([0, 1], form="all") u_ans, v_ans, e_ans = g.in_edges([0, 1], form="all")
if p is not None: if p is not None:
emask = F.gather_row(g.edata[p], e_ans) emask = F.gather_row(g.edata[p], e_ans)
...@@ -576,12 +584,17 @@ def _test_sample_neighbors(hypersparse, prob): ...@@ -576,12 +584,17 @@ def _test_sample_neighbors(hypersparse, prob):
assert uv == uv_ans assert uv == uv_ans
for i in range(10): for i in range(10):
subg = dgl.sampling.sample_neighbors( subg = sample_neighbors_fusing_mode[fused](
g, [0, 1], 2, prob=p, replace=replace g, [0, 1], 2, prob=p, replace=replace
) )
if not fused:
assert subg.num_nodes() == g.num_nodes() assert subg.num_nodes() == g.num_nodes()
assert subg.num_edges() == 4 assert subg.num_edges() == 4
u, v = subg.edges() u, v = subg.edges()
if fused:
u, v = subg.srcdata[dgl.NID][u], subg.dstdata[dgl.NID][v]
assert set(F.asnumpy(F.unique(v))) == {0, 1} assert set(F.asnumpy(F.unique(v))) == {0, 1}
assert F.array_equal( assert F.array_equal(
F.astype(g.has_edges_between(u, v), F.int64), F.astype(g.has_edges_between(u, v), F.int64),
...@@ -600,11 +613,14 @@ def _test_sample_neighbors(hypersparse, prob): ...@@ -600,11 +613,14 @@ def _test_sample_neighbors(hypersparse, prob):
_test1(prob, False) # w/o replacement, uniform _test1(prob, False) # w/o replacement, uniform
def _test2(p, replace): # fanout > #neighbors def _test2(p, replace): # fanout > #neighbors
subg = dgl.sampling.sample_neighbors( subg = sample_neighbors_fusing_mode[fused](
g, [0, 2], -1, prob=p, replace=replace g, [0, 2], -1, prob=p, replace=replace
) )
if not fused:
assert subg.num_nodes() == g.num_nodes() assert subg.num_nodes() == g.num_nodes()
u, v = subg.edges() u, v = subg.edges()
if fused:
u, v = subg.srcdata[dgl.NID][u], subg.dstdata[dgl.NID][v]
u_ans, v_ans, e_ans = g.in_edges([0, 2], form="all") u_ans, v_ans, e_ans = g.in_edges([0, 2], form="all")
if p is not None: if p is not None:
emask = F.gather_row(g.edata[p], e_ans) emask = F.gather_row(g.edata[p], e_ans)
...@@ -617,13 +633,16 @@ def _test_sample_neighbors(hypersparse, prob): ...@@ -617,13 +633,16 @@ def _test_sample_neighbors(hypersparse, prob):
assert uv == uv_ans assert uv == uv_ans
for i in range(10): for i in range(10):
subg = dgl.sampling.sample_neighbors( subg = sample_neighbors_fusing_mode[fused](
g, [0, 2], 2, prob=p, replace=replace g, [0, 2], 2, prob=p, replace=replace
) )
if not fused:
assert subg.num_nodes() == g.num_nodes() assert subg.num_nodes() == g.num_nodes()
num_edges = 4 if replace else 3 num_edges = 4 if replace else 3
assert subg.num_edges() == num_edges assert subg.num_edges() == num_edges
u, v = subg.edges() u, v = subg.edges()
if fused:
u, v = subg.srcdata[dgl.NID][u], subg.dstdata[dgl.NID][v]
assert set(F.asnumpy(F.unique(v))) == {0, 2} assert set(F.asnumpy(F.unique(v))) == {0, 2}
assert F.array_equal( assert F.array_equal(
F.astype(g.has_edges_between(u, v), F.int64), F.astype(g.has_edges_between(u, v), F.int64),
...@@ -641,10 +660,13 @@ def _test_sample_neighbors(hypersparse, prob): ...@@ -641,10 +660,13 @@ def _test_sample_neighbors(hypersparse, prob):
_test2(prob, False) # w/o replacement, uniform _test2(prob, False) # w/o replacement, uniform
def _test3(p, replace): def _test3(p, replace):
subg = dgl.sampling.sample_neighbors( subg = sample_neighbors_fusing_mode[fused](
hg, {"user": [0, 1], "game": 0}, -1, prob=p, replace=replace hg, {"user": [0, 1], "game": 0}, -1, prob=p, replace=replace
) )
if not fused:
assert len(subg.ntypes) == 3 assert len(subg.ntypes) == 3
assert len(subg.srctypes) == 3
assert len(subg.dsttypes) == 3
assert len(subg.etypes) == 4 assert len(subg.etypes) == 4
assert subg["follow"].num_edges() == 6 if p is None else 4 assert subg["follow"].num_edges() == 6 if p is None else 4
assert subg["play"].num_edges() == 1 assert subg["play"].num_edges() == 1
...@@ -652,10 +674,13 @@ def _test_sample_neighbors(hypersparse, prob): ...@@ -652,10 +674,13 @@ def _test_sample_neighbors(hypersparse, prob):
assert subg["flips"].num_edges() == 0 assert subg["flips"].num_edges() == 0
for i in range(10): for i in range(10):
subg = dgl.sampling.sample_neighbors( subg = sample_neighbors_fusing_mode[fused](
hg, {"user": [0, 1], "game": 0}, 2, prob=p, replace=replace hg, {"user": [0, 1], "game": 0}, 2, prob=p, replace=replace
) )
if not fused:
assert len(subg.ntypes) == 3 assert len(subg.ntypes) == 3
assert len(subg.srctypes) == 3
assert len(subg.dsttypes) == 3
assert len(subg.etypes) == 4 assert len(subg.etypes) == 4
assert subg["follow"].num_edges() == 4 assert subg["follow"].num_edges() == 4
assert subg["play"].num_edges() == 2 if replace else 1 assert subg["play"].num_edges() == 2 if replace else 1
...@@ -667,13 +692,16 @@ def _test_sample_neighbors(hypersparse, prob): ...@@ -667,13 +692,16 @@ def _test_sample_neighbors(hypersparse, prob):
# test different fanouts for different relations # test different fanouts for different relations
for i in range(10): for i in range(10):
subg = dgl.sampling.sample_neighbors( subg = sample_neighbors_fusing_mode[fused](
hg, hg,
{"user": [0, 1], "game": 0, "coin": 0}, {"user": [0, 1], "game": 0, "coin": 0},
{"follow": 1, "play": 2, "liked-by": 0, "flips": -1}, {"follow": 1, "play": 2, "liked-by": 0, "flips": -1},
replace=True, replace=True,
) )
if not fused:
assert len(subg.ntypes) == 3 assert len(subg.ntypes) == 3
assert len(subg.srctypes) == 3
assert len(subg.dsttypes) == 3
assert len(subg.etypes) == 4 assert len(subg.etypes) == 4
assert subg["follow"].num_edges() == 2 assert subg["follow"].num_edges() == 2
assert subg["play"].num_edges() == 2 assert subg["play"].num_edges() == 2
...@@ -795,15 +823,19 @@ def _test_sample_labors(hypersparse, prob): ...@@ -795,15 +823,19 @@ def _test_sample_labors(hypersparse, prob):
assert subg["flips"].num_edges() == 4 assert subg["flips"].num_edges() == 4
def _test_sample_neighbors_outedge(hypersparse): def _test_sample_neighbors_outedge(hypersparse, fused):
g, hg = _gen_neighbor_sampling_test_graph(hypersparse, True) g, hg = _gen_neighbor_sampling_test_graph(hypersparse, True)
def _test1(p, replace): def _test1(p, replace):
subg = dgl.sampling.sample_neighbors( subg = sample_neighbors_fusing_mode[fused](
g, [0, 1], -1, prob=p, replace=replace, edge_dir="out" g, [0, 1], -1, prob=p, replace=replace, edge_dir="out"
) )
if not fused:
assert subg.num_nodes() == g.num_nodes() assert subg.num_nodes() == g.num_nodes()
u, v = subg.edges() u, v = subg.edges()
if fused:
u, v = subg.dstdata[dgl.NID][u], subg.srcdata[dgl.NID][v]
u_ans, v_ans, e_ans = g.out_edges([0, 1], form="all") u_ans, v_ans, e_ans = g.out_edges([0, 1], form="all")
if p is not None: if p is not None:
emask = F.gather_row(g.edata[p], e_ans) emask = F.gather_row(g.edata[p], e_ans)
...@@ -816,12 +848,15 @@ def _test_sample_neighbors_outedge(hypersparse): ...@@ -816,12 +848,15 @@ def _test_sample_neighbors_outedge(hypersparse):
assert uv == uv_ans assert uv == uv_ans
for i in range(10): for i in range(10):
subg = dgl.sampling.sample_neighbors( subg = sample_neighbors_fusing_mode[fused](
g, [0, 1], 2, prob=p, replace=replace, edge_dir="out" g, [0, 1], 2, prob=p, replace=replace, edge_dir="out"
) )
if not fused:
assert subg.num_nodes() == g.num_nodes() assert subg.num_nodes() == g.num_nodes()
assert subg.num_edges() == 4 assert subg.num_edges() == 4
u, v = subg.edges() u, v = subg.edges()
if fused:
u, v = subg.dstdata[dgl.NID][u], subg.srcdata[dgl.NID][v]
assert set(F.asnumpy(F.unique(u))) == {0, 1} assert set(F.asnumpy(F.unique(u))) == {0, 1}
assert F.array_equal( assert F.array_equal(
F.astype(g.has_edges_between(u, v), F.int64), F.astype(g.has_edges_between(u, v), F.int64),
...@@ -842,11 +877,14 @@ def _test_sample_neighbors_outedge(hypersparse): ...@@ -842,11 +877,14 @@ def _test_sample_neighbors_outedge(hypersparse):
_test1("prob", False) # w/o replacement _test1("prob", False) # w/o replacement
def _test2(p, replace): # fanout > #neighbors def _test2(p, replace): # fanout > #neighbors
subg = dgl.sampling.sample_neighbors( subg = sample_neighbors_fusing_mode[fused](
g, [0, 2], -1, prob=p, replace=replace, edge_dir="out" g, [0, 2], -1, prob=p, replace=replace, edge_dir="out"
) )
if not fused:
assert subg.num_nodes() == g.num_nodes() assert subg.num_nodes() == g.num_nodes()
u, v = subg.edges() u, v = subg.edges()
if fused:
u, v = subg.dstdata[dgl.NID][u], subg.srcdata[dgl.NID][v]
u_ans, v_ans, e_ans = g.out_edges([0, 2], form="all") u_ans, v_ans, e_ans = g.out_edges([0, 2], form="all")
if p is not None: if p is not None:
emask = F.gather_row(g.edata[p], e_ans) emask = F.gather_row(g.edata[p], e_ans)
...@@ -859,13 +897,17 @@ def _test_sample_neighbors_outedge(hypersparse): ...@@ -859,13 +897,17 @@ def _test_sample_neighbors_outedge(hypersparse):
assert uv == uv_ans assert uv == uv_ans
for i in range(10): for i in range(10):
subg = dgl.sampling.sample_neighbors( subg = sample_neighbors_fusing_mode[fused](
g, [0, 2], 2, prob=p, replace=replace, edge_dir="out" g, [0, 2], 2, prob=p, replace=replace, edge_dir="out"
) )
if not fused:
assert subg.num_nodes() == g.num_nodes() assert subg.num_nodes() == g.num_nodes()
num_edges = 4 if replace else 3 num_edges = 4 if replace else 3
assert subg.num_edges() == num_edges assert subg.num_edges() == num_edges
u, v = subg.edges() u, v = subg.edges()
if fused:
u, v = subg.dstdata[dgl.NID][u], subg.srcdata[dgl.NID][v]
assert set(F.asnumpy(F.unique(u))) == {0, 2} assert set(F.asnumpy(F.unique(u))) == {0, 2}
assert F.array_equal( assert F.array_equal(
F.astype(g.has_edges_between(u, v), F.int64), F.astype(g.has_edges_between(u, v), F.int64),
...@@ -885,7 +927,7 @@ def _test_sample_neighbors_outedge(hypersparse): ...@@ -885,7 +927,7 @@ def _test_sample_neighbors_outedge(hypersparse):
_test2("prob", False) # w/o replacement _test2("prob", False) # w/o replacement
def _test3(p, replace): def _test3(p, replace):
subg = dgl.sampling.sample_neighbors( subg = sample_neighbors_fusing_mode[fused](
hg, hg,
{"user": [0, 1], "game": 0}, {"user": [0, 1], "game": 0},
-1, -1,
...@@ -893,7 +935,11 @@ def _test_sample_neighbors_outedge(hypersparse): ...@@ -893,7 +935,11 @@ def _test_sample_neighbors_outedge(hypersparse):
replace=replace, replace=replace,
edge_dir="out", edge_dir="out",
) )
if not fused:
assert len(subg.ntypes) == 3 assert len(subg.ntypes) == 3
assert len(subg.srctypes) == 3
assert len(subg.dsttypes) == 3
assert len(subg.etypes) == 4 assert len(subg.etypes) == 4
assert subg["follow"].num_edges() == 6 if p is None else 4 assert subg["follow"].num_edges() == 6 if p is None else 4
assert subg["play"].num_edges() == 1 assert subg["play"].num_edges() == 1
...@@ -901,7 +947,7 @@ def _test_sample_neighbors_outedge(hypersparse): ...@@ -901,7 +947,7 @@ def _test_sample_neighbors_outedge(hypersparse):
assert subg["flips"].num_edges() == 0 assert subg["flips"].num_edges() == 0
for i in range(10): for i in range(10):
subg = dgl.sampling.sample_neighbors( subg = sample_neighbors_fusing_mode[fused](
hg, hg,
{"user": [0, 1], "game": 0}, {"user": [0, 1], "game": 0},
2, 2,
...@@ -909,7 +955,10 @@ def _test_sample_neighbors_outedge(hypersparse): ...@@ -909,7 +955,10 @@ def _test_sample_neighbors_outedge(hypersparse):
replace=replace, replace=replace,
edge_dir="out", edge_dir="out",
) )
if not fused:
assert len(subg.ntypes) == 3 assert len(subg.ntypes) == 3
assert len(subg.srctypes) == 3
assert len(subg.dsttypes) == 3
assert len(subg.etypes) == 4 assert len(subg.etypes) == 4
assert subg["follow"].num_edges() == 4 assert subg["follow"].num_edges() == 4
assert subg["play"].num_edges() == 2 if replace else 1 assert subg["play"].num_edges() == 2 if replace else 1
...@@ -1077,7 +1126,9 @@ def _test_sample_neighbors_topk_outedge(hypersparse): ...@@ -1077,7 +1126,9 @@ def _test_sample_neighbors_topk_outedge(hypersparse):
def test_sample_neighbors_noprob():
    """Run the neighbor-sampling checks without a probability/mask feature.

    The non-fused sampler is always exercised; the fused sampler is exercised
    only on CPU with the PyTorch backend.
    """
    _test_sample_neighbors(False, None, False)
    if not (F._default_context_str == "gpu" or F.backend_name != "pytorch"):
        _test_sample_neighbors(False, None, True)
    # _test_sample_neighbors(True)
...@@ -1086,7 +1137,9 @@ def test_sample_labors_noprob(): ...@@ -1086,7 +1137,9 @@ def test_sample_labors_noprob():
def test_sample_neighbors_prob():
    """Run the neighbor-sampling checks with the ``"prob"`` edge feature.

    The non-fused sampler is always exercised; the fused sampler is exercised
    only on CPU with the PyTorch backend.
    """
    _test_sample_neighbors(False, "prob", False)
    if not (F._default_context_str == "gpu" or F.backend_name != "pytorch"):
        _test_sample_neighbors(False, "prob", True)
    # _test_sample_neighbors(True)
...@@ -1095,7 +1148,9 @@ def test_sample_labors_prob(): ...@@ -1095,7 +1148,9 @@ def test_sample_labors_prob():
def test_sample_neighbors_outedge():
    """Run the out-edge neighbor-sampling checks (non-hypersparse graph).

    The non-fused sampler is always exercised; the fused sampler is exercised
    only on CPU with the PyTorch backend.
    """
    _test_sample_neighbors_outedge(False, False)
    if not (F._default_context_str == "gpu" or F.backend_name != "pytorch"):
        _test_sample_neighbors_outedge(False, True)
    # _test_sample_neighbors_outedge(True)
...@@ -1107,7 +1162,9 @@ def test_sample_neighbors_outedge(): ...@@ -1107,7 +1162,9 @@ def test_sample_neighbors_outedge():
reason="GPU sample neighbors with mask not implemented", reason="GPU sample neighbors with mask not implemented",
) )
def test_sample_neighbors_mask():
    """Run the neighbor-sampling checks with the ``"mask"`` edge feature.

    The non-fused sampler is always exercised; the fused sampler is exercised
    only on CPU with the PyTorch backend.
    """
    _test_sample_neighbors(False, "mask", False)
    if not (F._default_context_str == "gpu" or F.backend_name != "pytorch"):
        _test_sample_neighbors(False, "mask", True)
@unittest.skipIf( @unittest.skipIf(
...@@ -1128,21 +1185,26 @@ def test_sample_neighbors_topk_outedge(): ...@@ -1128,21 +1185,26 @@ def test_sample_neighbors_topk_outedge():
# _test_sample_neighbors_topk_outedge(True) # _test_sample_neighbors_topk_outedge(True)
@pytest.mark.parametrize("fused", [False, True])
def test_sample_neighbors_with_0deg(fused):
    """Sampling from a graph that has no edges must yield no edges.

    A 5-node graph without edges is sampled with seeds ``[1, 2]`` and fanout
    2 for every combination of edge direction ("in"/"out") and replacement
    mode, using either the regular or the fused sampler depending on
    ``fused``.
    """
    if fused and not (
        F._default_context_str != "gpu" and F.backend_name == "pytorch"
    ):
        pytest.skip("Fused sampling support CPU with backend PyTorch.")
    g = dgl.graph(([], []), num_nodes=5).to(F.ctx())
    sampler = sample_neighbors_fusing_mode[fused]
    # Same order as the unrolled original: in/False, in/True, out/False,
    # out/True.
    for direction in ("in", "out"):
        for with_replacement in (False, True):
            seeds = F.tensor([1, 2], dtype=F.int64)
            sg = sampler(
                g, seeds, 2, edge_dir=direction, replace=with_replacement
            )
            assert sg.num_edges() == 0
...@@ -1274,7 +1336,7 @@ def test_sample_neighbors_biased_homogeneous(): ...@@ -1274,7 +1336,7 @@ def test_sample_neighbors_biased_homogeneous():
) )
def test_sample_neighbors_biased_bipartite(): def test_sample_neighbors_biased_bipartite():
g = create_test_graph(100, 30, True) g = create_test_graph(100, 30, True)
num_dst = g.number_of_dst_nodes() num_dst = g.num_dst_nodes()
bias = F.tensor([0, 0.01, 10, 10], dtype=F.float32) bias = F.tensor([0, 0.01, 10, 10], dtype=F.float32)
def check_num(nodes, tag): def check_num(nodes, tag):
...@@ -1492,7 +1554,12 @@ def test_sample_neighbors_etype_sorted_homogeneous(format_, direction): ...@@ -1492,7 +1554,12 @@ def test_sample_neighbors_etype_sorted_homogeneous(format_, direction):
@pytest.mark.parametrize("dtype", ["int32", "int64"]) @pytest.mark.parametrize("dtype", ["int32", "int64"])
def test_sample_neighbors_exclude_edges_heteroG(dtype): @pytest.mark.parametrize("fused", [False, True])
def test_sample_neighbors_exclude_edges_heteroG(dtype, fused):
if fused and (
F._default_context_str == "gpu" or F.backend_name != "pytorch"
):
pytest.skip("Fused sampling support CPU with backend PyTorch.")
d_i_d_u_nodes = F.zerocopy_from_numpy( d_i_d_u_nodes = F.zerocopy_from_numpy(
np.unique(np.random.randint(300, size=100, dtype=dtype)) np.unique(np.random.randint(300, size=100, dtype=dtype))
) )
...@@ -1565,7 +1632,7 @@ def test_sample_neighbors_exclude_edges_heteroG(dtype): ...@@ -1565,7 +1632,7 @@ def test_sample_neighbors_exclude_edges_heteroG(dtype):
("drug", "treats", "disease"): excluded_d_t_d_edges, ("drug", "treats", "disease"): excluded_d_t_d_edges,
} }
sg = dgl.sampling.sample_neighbors( sg = sample_neighbors_fusing_mode[fused](
g, g,
{ {
"drug": sampled_drug_node, "drug": sampled_drug_node,
...@@ -1576,6 +1643,48 @@ def test_sample_neighbors_exclude_edges_heteroG(dtype): ...@@ -1576,6 +1643,48 @@ def test_sample_neighbors_exclude_edges_heteroG(dtype):
exclude_edges=excluded_edges, exclude_edges=excluded_edges,
) )
if fused:
def contain_edge(g, sg, etype, u, v):
    """Tell whether ``sg`` kept any of the excluded ``(u, v)`` edges.

    The edges of type ``etype`` in the subgraph are mapped back to endpoint
    pairs of the original graph ``g`` through the edge IDs stored under
    ``dgl.EID``; the result is True when at least one of those pairs also
    appears among the pairs formed from ``u`` and ``v``.
    """
    # Endpoint pairs, in the original graph, of the edges kept by sampling.
    kept_eids = sg.edges[etype].data[dgl.EID]
    kept_endpoints = np.stack(g.find_edges(kept_eids, etype), axis=1)
    kept_pairs = {tuple(pair) for pair in kept_endpoints}
    # Endpoint pairs that were requested to be excluded.
    forbidden_pairs = {tuple(pair) for pair in np.stack((u, v), axis=1)}
    return not kept_pairs.isdisjoint(forbidden_pairs)
assert not contain_edge(
g,
sg,
("drug", "interacts", "drug"),
did_excluded_nodes_U,
did_excluded_nodes_V,
)
assert not contain_edge(
g,
sg,
("drug", "interacts", "gene"),
dig_excluded_nodes_U,
dig_excluded_nodes_V,
)
assert not contain_edge(
g,
sg,
("drug", "treats", "disease"),
dtd_excluded_nodes_U,
dtd_excluded_nodes_V,
)
else:
assert not np.any( assert not np.any(
F.asnumpy( F.asnumpy(
sg.has_edges_between( sg.has_edges_between(
...@@ -1606,7 +1715,12 @@ def test_sample_neighbors_exclude_edges_heteroG(dtype): ...@@ -1606,7 +1715,12 @@ def test_sample_neighbors_exclude_edges_heteroG(dtype):
@pytest.mark.parametrize("dtype", ["int32", "int64"]) @pytest.mark.parametrize("dtype", ["int32", "int64"])
def test_sample_neighbors_exclude_edges_homoG(dtype): @pytest.mark.parametrize("fused", [False, True])
def test_sample_neighbors_exclude_edges_homoG(dtype, fused):
if fused and (
F._default_context_str == "gpu" or F.backend_name != "pytorch"
):
pytest.skip("Fused sampling support CPU with backend PyTorch.")
u_nodes = F.zerocopy_from_numpy( u_nodes = F.zerocopy_from_numpy(
np.unique(np.random.randint(300, size=100, dtype=dtype)) np.unique(np.random.randint(300, size=100, dtype=dtype))
) )
...@@ -1629,10 +1743,30 @@ def test_sample_neighbors_exclude_edges_homoG(dtype): ...@@ -1629,10 +1743,30 @@ def test_sample_neighbors_exclude_edges_homoG(dtype):
excluded_nodes_U = g_edges[U][b_idx:e_idx] excluded_nodes_U = g_edges[U][b_idx:e_idx]
excluded_nodes_V = g_edges[V][b_idx:e_idx] excluded_nodes_V = g_edges[V][b_idx:e_idx]
sg = dgl.sampling.sample_neighbors( sg = sample_neighbors_fusing_mode[fused](
g, sampled_node, sampled_amount, exclude_edges=excluded_edges g, sampled_node, sampled_amount, exclude_edges=excluded_edges
) )
if fused:
def contain_edge(g, sg, u, v):
    """Tell whether ``sg`` kept any of the excluded ``(u, v)`` edges.

    The ``"_E"`` edges of the subgraph are mapped back to endpoint pairs of
    the original graph ``g`` through the edge IDs stored under ``dgl.EID``;
    the result is True when at least one of those pairs also appears among
    the pairs formed from ``u`` and ``v``.
    """
    # Endpoint pairs, in the original graph, of the edges kept by sampling.
    kept_eids = sg.edges["_E"].data[dgl.EID]
    kept_endpoints = np.stack(g.find_edges(kept_eids), axis=1)
    kept_pairs = {tuple(pair) for pair in kept_endpoints}
    # Endpoint pairs that were requested to be excluded.
    forbidden_pairs = {tuple(pair) for pair in np.stack((u, v), axis=1)}
    return not kept_pairs.isdisjoint(forbidden_pairs)
assert not contain_edge(g, sg, excluded_nodes_U, excluded_nodes_V)
else:
assert not np.any( assert not np.any(
F.asnumpy(sg.has_edges_between(excluded_nodes_U, excluded_nodes_V)) F.asnumpy(sg.has_edges_between(excluded_nodes_U, excluded_nodes_V))
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment