"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "56f740051dae2d410677292a5c9e5b66e60f87dc"
Unverified Commit 2bca4759 authored by Quan (Andy) Gan, committed by GitHub

[Sampling] Enable sampling with edge masks in sample_etype_neighbors (#4749)

* sample neighbors with masks

* oops

* refactor again

* remove

* remove debug code

* rename macro

* address comments

* more stuff

* remove

* fix

* try fix unit test

* oops

* fix test

* oops

* change name

* rename a lot of stuff

* oops

* ugh

* misc fixes

* lint

* address a lot of comments

* lint

* lint

* fix

* that was silly

* fix

* fix

* fix

* oops
parent 72781efb
...@@ -429,11 +429,11 @@ COOMatrix COORowWiseSampling( ...@@ -429,11 +429,11 @@ COOMatrix COORowWiseSampling(
* // coo.rows = [0, 0, 0, 0, 3] * // coo.rows = [0, 0, 0, 0, 3]
* // coo.cols = [0, 1, 3, 2, 3] * // coo.cols = [0, 1, 3, 2, 3]
* // coo.data = [2, 3, 0, 1, 4] * // coo.data = [2, 3, 0, 1, 4]
* // etype = [0, 0, 0, 2, 1] * // eid2etype_offset = [0, 3, 4, 5]
* COOMatrix coo = ...; * COOMatrix coo = ...;
* IdArray rows = ... ; // [0, 3] * IdArray rows = ... ; // [0, 3]
* std::vector<int64_t> num_samples = {2, 2, 2}; * std::vector<int64_t> num_samples = {2, 2, 2};
* COOMatrix sampled = COORowWisePerEtypeSampling(coo, rows, etype, num_samples, * COOMatrix sampled = COORowWisePerEtypeSampling(coo, rows, eid2etype_offset, num_samples,
* FloatArray(), false); * FloatArray(), false);
* // possible sampled coo matrix: * // possible sampled coo matrix:
* // sampled.num_rows = 4 * // sampled.num_rows = 4
...@@ -444,23 +444,23 @@ COOMatrix COORowWiseSampling( ...@@ -444,23 +444,23 @@ COOMatrix COORowWiseSampling(
* *
* \param mat Input coo matrix. * \param mat Input coo matrix.
* \param rows Rows to sample from. * \param rows Rows to sample from.
* \param etypes Edge types of each edge. * \param eid2etype_offset The offset to each edge type.
* \param num_samples Number of samples * \param num_samples Number of samples
* \param prob Unnormalized probability array. Should be of the same length as the data array. * \param prob_or_mask Unnormalized probability array or mask array.
* If an empty array is provided, assume uniform. * Should be of the same length as the data array.
* If an empty array is provided, assume uniform.
* \param replace True if sample with replacement * \param replace True if sample with replacement
* \param etype_sorted True if the edge types are already sorted
* \return A COOMatrix storing the picked row and col indices. Its data field stores the * \return A COOMatrix storing the picked row and col indices. Its data field stores the
* the index of the picked elements in the value array. * the index of the picked elements in the value array.
* \note The edges of the entire graph must be ordered by their edge types.
*/ */
COOMatrix COORowWisePerEtypeSampling( COOMatrix COORowWisePerEtypeSampling(
COOMatrix mat, COOMatrix mat,
IdArray rows, IdArray rows,
IdArray etypes, const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_samples, const std::vector<int64_t>& num_samples,
FloatArray prob = FloatArray(), const std::vector<NDArray>& prob_or_mask,
bool replace = true, bool replace = true);
bool etype_sorted = false);
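For orientation, a minimal Python sketch (not part of the patch) of the eid2etype_offset convention used by the new signature: edge IDs are grouped by type, and the same upper-bound search that the C++ pickers use recovers the type and per-type ID of any homogenized edge ID.

import bisect

eid2etype_offset = [0, 3, 4, 5]    # type 0 -> eids 0..2, type 1 -> eid 3, type 2 -> eid 4

def split_eid(homogenized_eid):
    # Mirror of the std::upper_bound lookup in the row-wise pickers.
    etype = bisect.bisect_right(eid2etype_offset, homogenized_eid) - 1
    return etype, homogenized_eid - eid2etype_offset[etype]

assert split_eid(2) == (0, 2)
assert split_eid(3) == (1, 0)
assert split_eid(4) == (2, 0)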
/*! /*!
* \brief Select K non-zero entries with the largest weights along each given row. * \brief Select K non-zero entries with the largest weights along each given row.
......
...@@ -428,6 +428,7 @@ CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries); ...@@ -428,6 +428,7 @@ CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);
* If an empty array is provided, assume uniform. * If an empty array is provided, assume uniform.
* \param replace True if sample with replacement * \param replace True if sample with replacement
* \return A COOMatrix storing the picked row, col and data indices. * \return A COOMatrix storing the picked row, col and data indices.
* \note The edges of the entire graph must be ordered by their edge types.
*/ */
COOMatrix CSRRowWiseSampling( COOMatrix CSRRowWiseSampling(
CSRMatrix mat, CSRMatrix mat,
...@@ -455,11 +456,11 @@ COOMatrix CSRRowWiseSampling( ...@@ -455,11 +456,11 @@ COOMatrix CSRRowWiseSampling(
* // csr.indptr = [0, 4, 4, 4, 5] * // csr.indptr = [0, 4, 4, 4, 5]
* // csr.cols = [0, 1, 3, 2, 3] * // csr.cols = [0, 1, 3, 2, 3]
* // csr.data = [2, 3, 0, 1, 4] * // csr.data = [2, 3, 0, 1, 4]
* // etype = [0, 0, 0, 2, 1] * // eid2etype_offset = [0, 3, 4, 5]
* CSRMatrix csr = ...; * CSRMatrix csr = ...;
* IdArray rows = ... ; // [0, 3] * IdArray rows = ... ; // [0, 3]
* std::vector<int64_t> num_samples = {2, 2, 2}; * std::vector<int64_t> num_samples = {2, 2, 2};
* COOMatrix sampled = CSRRowWisePerEtypeSampling(csr, rows, etype, num_samples, * COOMatrix sampled = CSRRowWisePerEtypeSampling(csr, rows, eid2etype_offset, num_samples,
* FloatArray(), false); * FloatArray(), false);
* // possible sampled coo matrix: * // possible sampled coo matrix:
* // sampled.num_rows = 4 * // sampled.num_rows = 4
...@@ -470,22 +471,24 @@ COOMatrix CSRRowWiseSampling( ...@@ -470,22 +471,24 @@ COOMatrix CSRRowWiseSampling(
* *
* \param mat Input CSR matrix. * \param mat Input CSR matrix.
* \param rows Rows to sample from. * \param rows Rows to sample from.
* \param etypes Edge types of each edge. * \param eid2etype_offset The offset to each edge type.
* \param num_samples Number of samples to choose per edge type. * \param num_samples Number of samples to choose per edge type.
* \param prob Unnormalized probability array. Should be of the same length as the data array. * \param prob_or_mask Unnormalized probability array or mask array.
* If an empty array is provided, assume uniform. * Should be of the same length as the data array.
* If an empty array is provided, assume uniform.
* \param replace True if sample with replacement * \param replace True if sample with replacement
* \param etype_sorted True if the edge types are already sorted * \param rowwise_etype_sorted whether the CSR column indices per row are ordered by edge type.
* \return A COOMatrix storing the picked row, col and data indices. * \return A COOMatrix storing the picked row, col and data indices.
* \note The edges must be ordered by their edge types.
*/ */
COOMatrix CSRRowWisePerEtypeSampling( COOMatrix CSRRowWisePerEtypeSampling(
CSRMatrix mat, CSRMatrix mat,
IdArray rows, IdArray rows,
IdArray etypes, const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_samples, const std::vector<int64_t>& num_samples,
FloatArray prob = FloatArray(), const std::vector<NDArray>& prob_or_mask,
bool replace = true, bool replace = true,
bool etype_sorted = false); bool rowwise_etype_sorted = false);
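The new rowwise_etype_sorted flag concerns only the ordering inside each CSR row. A hedged NumPy sketch (input names are assumptions) of the property it asserts, using the example arrays from the comment above:

import numpy as np

def rowwise_etypes_sorted(indptr, data, eid2etype_offset):
    # True iff, within every CSR row, the edge types of the stored edges are non-decreasing.
    etype_of = lambda eid: np.searchsorted(eid2etype_offset, eid, side='right') - 1
    for r in range(len(indptr) - 1):
        types = [etype_of(data[j]) for j in range(indptr[r], indptr[r + 1])]
        if any(a > b for a, b in zip(types, types[1:])):
            return False
    return True

# With the example above, row 0 stores eids [2, 3, 0, 1] -> types [0, 1, 0, 0],
# so the check is False and the default rowwise_etype_sorted = false applies.
print(rowwise_etypes_sorted([0, 4, 4, 4, 5], [2, 3, 0, 1, 4], [0, 3, 4, 5]))   # False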
/*! /*!
* \brief Select K non-zero entries with the largest weights along each given row. * \brief Select K non-zero entries with the largest weights along each given row.
......
...@@ -28,6 +28,17 @@ class NeighborSampler(BlockSampler): ...@@ -28,6 +28,17 @@ class NeighborSampler(BlockSampler):
If given, the probability of each neighbor being sampled is proportional If given, the probability of each neighbor being sampled is proportional
to the edge feature value with the given name in ``g.edata``. The feature must be to the edge feature value with the given name in ``g.edata``. The feature must be
a scalar on each edge. a scalar on each edge.
This argument is mutually exclusive with :attr:`mask`. If you want to
specify both a mask and a probability, consider multiplying the probability
with the mask instead.
mask : str, optional
If given, a neighbor could be picked only if the edge mask with the given
name in ``g.edata`` is True. The data must be boolean on each edge.
This argument is mutually exclusive with :attr:`prob`. If you want to
specify both a mask and a probability, consider multiplying the probability
with the mask instead.
replace : bool, default False replace : bool, default False
Whether to sample with replacement Whether to sample with replacement
prefetch_node_feats : list[str] or dict[ntype, list[str]], optional prefetch_node_feats : list[str] or dict[ntype, list[str]], optional
...@@ -72,6 +83,11 @@ class NeighborSampler(BlockSampler): ...@@ -72,6 +83,11 @@ class NeighborSampler(BlockSampler):
>>> g.edata['p'] = torch.rand(g.num_edges()) # any non-negative 1D vector works >>> g.edata['p'] = torch.rand(g.num_edges()) # any non-negative 1D vector works
>>> sampler = dgl.dataloading.NeighborSampler([5, 10, 15], prob='p') >>> sampler = dgl.dataloading.NeighborSampler([5, 10, 15], prob='p')
Or sampling on edge masks:
>>> g.edata['mask'] = torch.rand(g.num_edges()) < 0.2 # any 1D boolean mask works
>>> sampler = dgl.dataloading.NeighborSampler([5, 10, 15], mask='mask')
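If both a probability and a mask are needed, the mutual-exclusivity note above suggests folding the mask into the probability first. A small sketch continuing the examples above (the feature name 'p_masked' is assumed):
>>> g.edata['p_masked'] = g.edata['p'] * g.edata['mask'].float()
>>> sampler = dgl.dataloading.NeighborSampler([5, 10, 15], prob='p_masked')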
**Edge classification and link prediction** **Edge classification and link prediction**
This class can also work for edge classification and link prediction together This class can also work for edge classification and link prediction together
...@@ -91,7 +107,7 @@ class NeighborSampler(BlockSampler): ...@@ -91,7 +107,7 @@ class NeighborSampler(BlockSampler):
:ref:`User Guide Section 6 <guide-minibatch>` and :ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`. :doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
""" """
def __init__(self, fanouts, edge_dir='in', prob=None, replace=False, def __init__(self, fanouts, edge_dir='in', prob=None, mask=None, replace=False,
prefetch_node_feats=None, prefetch_labels=None, prefetch_edge_feats=None, prefetch_node_feats=None, prefetch_labels=None, prefetch_edge_feats=None,
output_device=None): output_device=None):
super().__init__(prefetch_node_feats=prefetch_node_feats, super().__init__(prefetch_node_feats=prefetch_node_feats,
...@@ -100,7 +116,12 @@ class NeighborSampler(BlockSampler): ...@@ -100,7 +116,12 @@ class NeighborSampler(BlockSampler):
output_device=output_device) output_device=output_device)
self.fanouts = fanouts self.fanouts = fanouts
self.edge_dir = edge_dir self.edge_dir = edge_dir
self.prob = prob if mask is not None and prob is not None:
raise ValueError(
'Mask and probability arguments are mutually exclusive. '
'Consider multiplying the probability with the mask '
'to achieve the same goal.')
self.prob = prob or mask
self.replace = replace self.replace = replace
def sample_blocks(self, g, seed_nodes, exclude_eids=None): def sample_blocks(self, g, seed_nodes, exclude_eids=None):
......
...@@ -83,7 +83,10 @@ def _copy_graph_to_shared_mem(g, graph_name, graph_format): ...@@ -83,7 +83,10 @@ def _copy_graph_to_shared_mem(g, graph_name, graph_format):
# for heterogeneous graph, we need to put ETYPE into KVStore # for heterogeneous graph, we need to put ETYPE into KVStore
# for homogeneous graph, ETYPE does not exist # for homogeneous graph, ETYPE does not exist
if ETYPE in g.edata: if ETYPE in g.edata:
new_g.edata[ETYPE] = _to_shared_mem(g.edata[ETYPE], _get_edata_path(graph_name, ETYPE)) new_g.edata[ETYPE] = _to_shared_mem(
g.edata[ETYPE],
_get_edata_path(graph_name, ETYPE),
)
return new_g return new_g
def _get_shared_mem_ndata(g, graph_name, name): def _get_shared_mem_ndata(g, graph_name, name):
...@@ -378,6 +381,7 @@ class DistGraphServer(KVServer): ...@@ -378,6 +381,7 @@ class DistGraphServer(KVServer):
# The feature name has the following format: edge_type + "/" + feature_name to avoid # The feature name has the following format: edge_type + "/" + feature_name to avoid
# feature name collision for different edge types. # feature name collision for different edge types.
etype, feat_name = name.split('/') etype, feat_name = name.split('/')
data_name = HeteroDataName(False, etype, feat_name) data_name = HeteroDataName(False, etype, feat_name)
self.init_data(name=str(data_name), policy_str=data_name.policy_str, self.init_data(name=str(data_name), policy_str=data_name.policy_str,
data_tensor=edge_feats[name]) data_tensor=edge_feats[name])
...@@ -1264,13 +1268,13 @@ class DistGraph: ...@@ -1264,13 +1268,13 @@ class DistGraph:
output_device=None): output_device=None):
# pylint: disable=unused-argument # pylint: disable=unused-argument
"""Sample neighbors from a distributed graph.""" """Sample neighbors from a distributed graph."""
# Currently prob, exclude_edges, output_device, and edge_dir are ignored.
if len(self.etypes) > 1: if len(self.etypes) > 1:
frontier = graph_services.sample_etype_neighbors( frontier = graph_services.sample_etype_neighbors(
self, seed_nodes, ETYPE, fanout, replace=replace, etype_sorted=etype_sorted) self, seed_nodes, fanout, replace=replace,
etype_sorted=etype_sorted, prob=prob)
else: else:
frontier = graph_services.sample_neighbors( frontier = graph_services.sample_neighbors(
self, seed_nodes, fanout, replace=replace) self, seed_nodes, fanout, replace=replace, prob=prob)
return frontier return frontier
def _get_ndata_names(self, ntype=None): def _get_ndata_names(self, ntype=None):
......
...@@ -181,6 +181,16 @@ class DistTensor: ...@@ -181,6 +181,16 @@ class DistTensor:
# TODO(zhengda) how do we want to support broadcast (e.g., G.ndata['h'][idx] = 1). # TODO(zhengda) how do we want to support broadcast (e.g., G.ndata['h'][idx] = 1).
self.kvstore.push(name=self._name, id_tensor=idx, data_tensor=val) self.kvstore.push(name=self._name, id_tensor=idx, data_tensor=val)
@property
def kvstore_key(self):
"""Return the key string of this DistTensor in the associated KVStore."""
return self._name
@property
def local_partition(self):
"""Return the local partition of this DistTensor."""
return self.kvstore.data_store[self._name]
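A hedged sketch of how these two properties are consumed by the distributed samplers later in this change (the feature name 'p' and the surrounding DistGraph/KVStore setup are assumptions): the trainer ships only the KVStore key over RPC, and server-side code resolves it against its local shard.
# Trainer side: refer to the distributed tensor by its key, not its data.
prob_key = g.edata['p'].kvstore_key          # a plain string, cheap to put in a request
# Server side (or local access): read the shard held by this process directly.
local_prob = kv_store.data_store[prob_key]   # same tensor as g.edata['p'].local_partition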
def __or__(self, other): def __or__(self, other):
new_dist_tensor = DistTensor( new_dist_tensor = DistTensor(
self._shape, self._shape,
......
...@@ -874,6 +874,18 @@ class RangePartitionBook(GraphPartitionBook): ...@@ -874,6 +874,18 @@ class RangePartitionBook(GraphPartitionBook):
self._nid_map = IdMap(self._typed_nid_range) self._nid_map = IdMap(self._typed_nid_range)
self._eid_map = IdMap(self._typed_eid_range) self._eid_map = IdMap(self._typed_eid_range)
# Local node/edge type offset that maps the local homogenized node/edge IDs
# to local heterogenized node/edge IDs. One can do the mapping by binary search
# on these arrays.
self._local_ntype_offset = np.cumsum(
[0] + [
v[self._partid, 1] - v[self._partid, 0]
for v in self._typed_nid_range.values()]).tolist()
self._local_etype_offset = np.cumsum(
[0] + [
v[self._partid, 1] - v[self._partid, 0]
for v in self._typed_eid_range.values()]).tolist()
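A small NumPy sketch of what these cumulative sums encode (the per-type edge counts below are made up): the offsets partition the local homogenized ID space by type, so a binary search recovers the heterogenized type.
import numpy as np

per_etype_edge_counts = [3, 1, 1]                                        # assumed counts for this partition
local_etype_offset = np.cumsum([0] + per_etype_edge_counts).tolist()     # [0, 3, 4, 5]

local_homogenized_eid = 3
etype_id = np.searchsorted(local_etype_offset, local_homogenized_eid, side='right') - 1   # -> 1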
# Get meta data of the partition book # Get meta data of the partition book
self._partition_meta_data = [] self._partition_meta_data = []
for partid in range(self._num_partitions): for partid in range(self._num_partitions):
...@@ -1106,6 +1118,22 @@ class RangePartitionBook(GraphPartitionBook): ...@@ -1106,6 +1118,22 @@ class RangePartitionBook(GraphPartitionBook):
""" """
return self._canonical_etypes return self._canonical_etypes
@property
def local_ntype_offset(self):
"""Get the node type offset array of the local partition.
The i-th element indicates the starting position of the i-th node type.
"""
return self._local_ntype_offset
@property
def local_etype_offset(self):
"""Get the edge type offset array of the local partition.
The i-th element indicates the starting position of the i-th edge type.
"""
return self._local_etype_offset
def _to_canonical_etype(self, etype): def _to_canonical_etype(self, etype):
"""Convert an edge type to the corresponding canonical edge type. """Convert an edge type to the corresponding canonical edge type.
If canonical etype is not available, no conversion is applied. If canonical etype is not available, no conversion is applied.
......
...@@ -10,6 +10,7 @@ from ..sampling import sample_etype_neighbors as local_sample_etype_neighbors ...@@ -10,6 +10,7 @@ from ..sampling import sample_etype_neighbors as local_sample_etype_neighbors
from ..sampling import sample_neighbors as local_sample_neighbors from ..sampling import sample_neighbors as local_sample_neighbors
from ..subgraph import in_subgraph as local_in_subgraph from ..subgraph import in_subgraph as local_in_subgraph
from ..utils import toindex from ..utils import toindex
from .. import backend as F
from .rpc import ( from .rpc import (
Request, Request,
Response, Response,
...@@ -98,7 +99,7 @@ def _sample_etype_neighbors( ...@@ -98,7 +99,7 @@ def _sample_etype_neighbors(
local_g, local_g,
partition_book, partition_book,
seed_nodes, seed_nodes,
etype_field, etype_offset,
fan_out, fan_out,
edge_dir, edge_dir,
prob, prob,
...@@ -118,7 +119,7 @@ def _sample_etype_neighbors( ...@@ -118,7 +119,7 @@ def _sample_etype_neighbors(
sampled_graph = local_sample_etype_neighbors( sampled_graph = local_sample_etype_neighbors(
local_g, local_g,
local_ids, local_ids,
etype_field, etype_offset,
fan_out, fan_out,
edge_dir, edge_dir,
prob, prob,
...@@ -181,6 +182,31 @@ def _in_subgraph(local_g, partition_book, seed_nodes): ...@@ -181,6 +182,31 @@ def _in_subgraph(local_g, partition_book, seed_nodes):
return global_src, global_dst, global_eids return global_src, global_dst, global_eids
# --- NOTE 1 ---
# (BarclayII)
# If the sampling algorithm needs node and edge data, ideally the
# algorithm should query the underlying feature storage to get just what it
# needs to complete the job. For instance, with
# sample_etype_neighbors, we only need the probability of the seed nodes'
# neighbors.
#
# However, right now we are reusing the existing subgraph sampling
# interfaces of DGLGraph (i.e. single machine solution), which needs
# the data of *all* the nodes/edges. Going distributed, we now need
# the node/edge data of the *entire* local graph partition.
#
# If the sampling algorithm only uses edge data, the current design works
# because the local graph partition contains all the in-edges of the
# assigned nodes as well as the data. This is the case for
# sample_etype_neighbors.
#
# However, if the sampling algorithm requires data of the neighbor nodes
# (e.g. sample_neighbors_biased which performs biased sampling based on the
# type of the neighbor nodes), the current design will fail because the
# neighbor nodes (hence the data) may not belong to the current partition.
# This is a limitation of the current DistDGL design. We should improve it
# later.
class SamplingRequest(Request): class SamplingRequest(Request):
"""Sampling Request""" """Sampling Request"""
...@@ -212,13 +238,18 @@ class SamplingRequest(Request): ...@@ -212,13 +238,18 @@ class SamplingRequest(Request):
def process_request(self, server_state): def process_request(self, server_state):
local_g = server_state.graph local_g = server_state.graph
partition_book = server_state.partition_book partition_book = server_state.partition_book
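# See NOTE 1: the request carries only the KVStore key of the probability
# tensor; resolve it against the shard held by this server.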
kv_store = server_state.kv_store
if self.prob is not None:
prob = [kv_store.data_store[self.prob]]
else:
prob = None
global_src, global_dst, global_eids = _sample_neighbors( global_src, global_dst, global_eids = _sample_neighbors(
local_g, local_g,
partition_book, partition_book,
self.seed_nodes, self.seed_nodes,
self.fan_out, self.fan_out,
self.edge_dir, self.edge_dir,
self.prob, prob,
self.replace, self.replace,
) )
return SubgraphResponse(global_src, global_dst, global_eids) return SubgraphResponse(global_src, global_dst, global_eids)
...@@ -230,7 +261,6 @@ class SamplingRequestEtype(Request): ...@@ -230,7 +261,6 @@ class SamplingRequestEtype(Request):
def __init__( def __init__(
self, self,
nodes, nodes,
etype_field,
fan_out, fan_out,
edge_dir="in", edge_dir="in",
prob=None, prob=None,
...@@ -242,7 +272,6 @@ class SamplingRequestEtype(Request): ...@@ -242,7 +272,6 @@ class SamplingRequestEtype(Request):
self.prob = prob self.prob = prob
self.replace = replace self.replace = replace
self.fan_out = fan_out self.fan_out = fan_out
self.etype_field = etype_field
self.etype_sorted = etype_sorted self.etype_sorted = etype_sorted
def __setstate__(self, state): def __setstate__(self, state):
...@@ -252,7 +281,6 @@ class SamplingRequestEtype(Request): ...@@ -252,7 +281,6 @@ class SamplingRequestEtype(Request):
self.prob, self.prob,
self.replace, self.replace,
self.fan_out, self.fan_out,
self.etype_field,
self.etype_sorted, self.etype_sorted,
) = state ) = state
...@@ -263,21 +291,30 @@ class SamplingRequestEtype(Request): ...@@ -263,21 +291,30 @@ class SamplingRequestEtype(Request):
self.prob, self.prob,
self.replace, self.replace,
self.fan_out, self.fan_out,
self.etype_field,
self.etype_sorted, self.etype_sorted,
) )
def process_request(self, server_state): def process_request(self, server_state):
local_g = server_state.graph local_g = server_state.graph
partition_book = server_state.partition_book partition_book = server_state.partition_book
kv_store = server_state.kv_store
etype_offset = partition_book.local_etype_offset
# See NOTE 1
if self.prob is not None:
probs = [
kv_store.data_store[key] if key != "" else None
for key in self.prob
]
else:
probs = None
global_src, global_dst, global_eids = _sample_etype_neighbors( global_src, global_dst, global_eids = _sample_etype_neighbors(
local_g, local_g,
partition_book, partition_book,
self.seed_nodes, self.seed_nodes,
self.etype_field, etype_offset,
self.fan_out, self.fan_out,
self.edge_dir, self.edge_dir,
self.prob, probs,
self.replace, self.replace,
self.etype_sorted, self.etype_sorted,
) )
...@@ -536,7 +573,6 @@ def _frontier_to_heterogeneous_graph(g, frontier, gpb): ...@@ -536,7 +573,6 @@ def _frontier_to_heterogeneous_graph(g, frontier, gpb):
def sample_etype_neighbors( def sample_etype_neighbors(
g, g,
nodes, nodes,
etype_field,
fanout, fanout,
edge_dir="in", edge_dir="in",
prob=None, prob=None,
...@@ -552,8 +588,8 @@ def sample_etype_neighbors( ...@@ -552,8 +588,8 @@ def sample_etype_neighbors(
Node/edge features are not preserved. The original IDs of Node/edge features are not preserved. The original IDs of
the sampled edges are stored as the `dgl.EID` feature in the returned graph. the sampled edges are stored as the `dgl.EID` feature in the returned graph.
This function assumes the input is a homogeneous ``DGLGraph`` with the TRUE edge type This function assumes the input is a homogeneous ``DGLGraph`` with the edges
information stored as the edge data in `etype_field`. The sampled subgraph is also ordered by their edge types. The sampled subgraph is also
stored in the homogeneous graph format. That is, all nodes and edges are assigned stored in the homogeneous graph format. That is, all nodes and edges are assigned
with unique IDs (in contrast, we typically use a type name and a node/edge ID to with unique IDs (in contrast, we typically use a type name and a node/edge ID to
identify a node or an edge in ``DGLGraph``). We refer to this type of IDs identify a node or an edge in ``DGLGraph``). We refer to this type of IDs
...@@ -569,8 +605,6 @@ def sample_etype_neighbors( ...@@ -569,8 +605,6 @@ def sample_etype_neighbors(
nodes : tensor or dict nodes : tensor or dict
Node IDs to sample neighbors from. If it's a dict, it should contain only Node IDs to sample neighbors from. If it's a dict, it should contain only
one key-value pair to make this API consistent with dgl.sampling.sample_neighbors. one key-value pair to make this API consistent with dgl.sampling.sample_neighbors.
etype_field : string
The field in g.edata storing the edge type.
fanout : int or dict[etype, int] fanout : int or dict[etype, int]
The number of edges to be sampled for each node per edge type. If an integer The number of edges to be sampled for each node per edge type. If an integer
is given, DGL assumes that the same fanout is applied to every edge type. is given, DGL assumes that the same fanout is applied to every edge type.
...@@ -625,25 +659,47 @@ def sample_etype_neighbors( ...@@ -625,25 +659,47 @@ def sample_etype_neighbors(
nodes = F.cat(homo_nids, 0) nodes = F.cat(homo_nids, 0)
def issue_remote_req(node_ids): def issue_remote_req(node_ids):
if prob is not None:
# See NOTE 1
_prob = [
# NOTE (BarclayII)
# Currently DistGraph.edges[] does not accept canonical etype.
g.edges[etype].data[prob].kvstore_key
if prob in g.edges[etype].data
else ""
for etype in g.etypes
]
else:
_prob = None
return SamplingRequestEtype( return SamplingRequestEtype(
node_ids, node_ids,
etype_field,
fanout, fanout,
edge_dir=edge_dir, edge_dir=edge_dir,
prob=prob, prob=_prob,
replace=replace, replace=replace,
etype_sorted=etype_sorted, etype_sorted=etype_sorted,
) )
def local_access(local_g, partition_book, local_nids): def local_access(local_g, partition_book, local_nids):
etype_offset = gpb.local_etype_offset
# See NOTE 1
if prob is None:
_prob = None
else:
_prob = [
g.edges[etype].data[prob].local_partition
if prob in g.edges[etype].data
else None
for etype in g.etypes
]
return _sample_etype_neighbors( return _sample_etype_neighbors(
local_g, local_g,
partition_book, partition_book,
local_nids, local_nids,
etype_field, etype_offset,
fanout, fanout,
edge_dir, edge_dir,
prob, _prob,
replace, replace,
etype_sorted=etype_sorted, etype_sorted=etype_sorted,
) )
...@@ -723,13 +779,28 @@ def sample_neighbors(g, nodes, fanout, edge_dir="in", prob=None, replace=False): ...@@ -723,13 +779,28 @@ def sample_neighbors(g, nodes, fanout, edge_dir="in", prob=None, replace=False):
nodes = list(nodes.values())[0] nodes = list(nodes.values())[0]
def issue_remote_req(node_ids): def issue_remote_req(node_ids):
if prob is not None:
# See NOTE 1
_prob = g.edata[prob].kvstore_key
else:
_prob = None
return SamplingRequest( return SamplingRequest(
node_ids, fanout, edge_dir=edge_dir, prob=prob, replace=replace node_ids, fanout, edge_dir=edge_dir, prob=_prob, replace=replace
) )
def local_access(local_g, partition_book, local_nids): def local_access(local_g, partition_book, local_nids):
# See NOTE 1
_prob = (
[g.edata[prob].local_partition] if prob is not None else None
)
return _sample_neighbors( return _sample_neighbors(
local_g, partition_book, local_nids, fanout, edge_dir, prob, replace local_g,
partition_book,
local_nids,
fanout,
edge_dir,
_prob,
replace,
) )
frontier = _distributed_access(g, nodes, issue_remote_req, local_access) frontier = _distributed_access(g, nodes, issue_remote_req, local_access)
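For context, a hedged end-to-end sketch of what this enables on the trainer side (the graph name, feature name, and the surrounding DistDGL initialization are all assumptions, not taken from the patch):

import torch
import dgl.distributed as dist

# Assumes dgl.distributed.initialize(...) has been called and the partitions are loaded.
g = dist.DistGraph('my_graph')
g.edata['p'] = dist.DistTensor((g.num_edges(),), torch.float32, 'p')
# ... fill g.edata['p'] with non-negative edge weights ...
seeds = torch.arange(128)
frontier = dist.sample_neighbors(g, seeds, fanout=5, prob='p')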
......
...@@ -1394,6 +1394,17 @@ class KVClient(object): ...@@ -1394,6 +1394,17 @@ class KVClient(object):
total += res.num_local_nonzero total += res.num_local_nonzero
return total return total
@property
def data_store(self):
"""Return the local partition of the data storage.
Returns
-------
dict[str, Tensor]
The tensor storages of the local partition.
"""
return self._data_store
KVCLIENT = None KVCLIENT = None
def init_kvstore(ip_config, num_servers, role): def init_kvstore(ip_config, num_servers, role):
......
...@@ -103,6 +103,17 @@ class KVClient(object): ...@@ -103,6 +103,17 @@ class KVClient(object):
""" """
return F.count_nonzero(self._data[name]) return F.count_nonzero(self._data[name])
@property
def data_store(self):
"""Return the local partition of the data storage.
Returns
-------
dict[str, Tensor]
The tensor storages of the local partition.
"""
return self._data
def union(self, operand1_name, operand2_name, output_name): def union(self, operand1_name, operand2_name, output_name):
"""Compute the union of two mask arrays in the KVStore. """Compute the union of two mask arrays in the KVStore.
""" """
......
...@@ -14,9 +14,52 @@ __all__ = [ ...@@ -14,9 +14,52 @@ __all__ = [
'sample_neighbors_biased', 'sample_neighbors_biased',
'select_topk'] 'select_topk']
def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=None, def _prepare_edge_arrays(g, arg):
replace=False, copy_ndata=True, copy_edata=True, etype_sorted=False, """Converts the argument into a list of NDArrays.
_dist_training=False, output_device=None):
If the argument is already a list of NDArrays, return it as is; if it is a list
of tensors (possibly containing None placeholders), convert each entry.
If the argument is a string, converts g.edata[arg] into a list of NDArrays
ordered by the edge types.
"""
if isinstance(arg, list) and len(arg) > 0:
if isinstance(arg[0], nd.NDArray):
return arg
else:
# The list can have None as placeholders for empty arrays with
# undetermined data type.
dtype = None
ctx = None
result = []
for entry in arg:
if F.is_tensor(entry):
result.append(F.to_dgl_nd(entry))
dtype = F.dtype(entry)
ctx = F.context(entry)
else:
result.append(None)
result = [
F.to_dgl_nd(F.copy_to(F.tensor([], dtype=dtype), ctx))
if x is None else x
for x in result]
return result
elif arg is None:
return [nd.array([], ctx=nd.cpu())] * len(g.etypes)
else:
arrays = []
for etype in g.canonical_etypes:
if arg in g.edges[etype].data:
arrays.append(F.to_dgl_nd(g.edges[etype].data[arg]))
else:
arrays.append(nd.array([], ctx=nd.cpu()))
return arrays
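For readers of the helper above, a hedged summary of the three accepted argument forms (the graph hg and tensor t0 are assumptions):
# _prepare_edge_arrays(hg, None)        -> one empty NDArray per edge type (uniform sampling)
# _prepare_edge_arrays(hg, 'p')         -> per-etype NDArrays of hg.edges[etype].data['p'],
#                                          with an empty array where the feature is absent
# _prepare_edge_arrays(hg, [t0, None])  -> [to_dgl_nd(t0), empty array with t0's dtype/ctx]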
def sample_etype_neighbors(
g, nodes, etype_offset, fanout, edge_dir='in', prob=None,
replace=False, copy_ndata=True, copy_edata=True, etype_sorted=False,
_dist_training=False, output_device=None):
"""Sample neighboring edges of the given nodes and return the induced subgraph. """Sample neighboring edges of the given nodes and return the induced subgraph.
For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges
...@@ -35,25 +78,23 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No ...@@ -35,25 +78,23 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No
This argument can take a single ID tensor or a dictionary of node types and ID tensors. This argument can take a single ID tensor or a dictionary of node types and ID tensors.
If a single tensor is given, the graph must only have one type of nodes. If a single tensor is given, the graph must only have one type of nodes.
etype_field : string etype_offset : list[int]
The field in g.edata storing the edge type. The offset of each edge type ID.
fanout : Tensor fanout : Tensor
The number of edges to be sampled for each node per edge type. Must be a The number of edges to be sampled for each node per edge type. Must be a
1D tensor with the number of elements same as the number of edge types. 1D tensor with the number of elements same as the number of edge types.
If -1 is given, all of the neighbors will be selected. If -1 is given, all of the neighbors with non-zero probability will be selected.
edge_dir : str, optional edge_dir : str, optional
Determines whether to sample inbound or outbound edges. Determines whether to sample inbound or outbound edges.
Can take either ``in`` for inbound edges or ``out`` for outbound edges. Can take either ``in`` for inbound edges or ``out`` for outbound edges.
prob : str, optional prob : list[Tensor], optional
Feature name used as the (unnormalized) probabilities associated with each The (unnormalized) probabilities associated with each neighboring edge of
neighboring edge of a node. The feature must have only one element for each a node.
edge.
The features must be non-negative floats, and the sum of the features of The features must be non-negative floats or boolean. Otherwise, the
inbound/outbound edges for every node must be positive (though they don't have result will be undefined.
to sum up to one). Otherwise, the result will be undefined.
replace : bool, optional replace : bool, optional
If True, sample with replacement. If True, sample with replacement.
copy_ndata: bool, optional copy_ndata: bool, optional
...@@ -94,9 +135,6 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No ...@@ -94,9 +135,6 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No
""" """
if g.device != F.cpu(): if g.device != F.cpu():
raise DGLError("The graph should be in cpu.") raise DGLError("The graph should be in cpu.")
if etype_field not in g.edata:
raise DGLError("The graph should have {} in the edge data" \
"representing the edge type.".format(etype_field))
# (BarclayII) because the homogenized graph no longer contains the *name* of edge # (BarclayII) because the homogenized graph no longer contains the *name* of edge
# types, the fanout argument can no longer be a dict of etypes and ints, as opposed # types, the fanout argument can no longer be a dict of etypes and ints, as opposed
# to sample_neighbors. # to sample_neighbors.
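Callers that previously used a per-etype fanout dict can order it by the graph's etypes themselves; a minimal sketch (the helper name is assumed):

import torch

def fanout_as_tensor(hetero_g, fanout_dict):
    # Order the per-etype fanouts by hetero_g.etypes, as the homogenized API expects.
    return torch.LongTensor([fanout_dict[etype] for etype in hetero_g.etypes])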
...@@ -105,26 +143,19 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No ...@@ -105,26 +143,19 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No
if isinstance(nodes, dict): if isinstance(nodes, dict):
assert len(nodes) == 1, "The input graph should not have node types" assert len(nodes) == 1, "The input graph should not have node types"
nodes = list(nodes.values())[0] nodes = list(nodes.values())[0]
nodes = utils.prepare_tensor(g, nodes, 'nodes') nodes = utils.prepare_tensor(g, nodes, 'nodes')
device = utils.context_of(nodes) device = utils.context_of(nodes)
nodes = F.to_dgl_nd(nodes) nodes = F.to_dgl_nd(nodes)
# treat etypes as int32, it is much cheaper than int64 # treat etypes as int32, it is much cheaper than int64
# TODO(xiangsx): int8 can be a better choice. # TODO(xiangsx): int8 can be a better choice.
etypes = F.to_dgl_nd(F.astype(g.edata[etype_field], ty=F.int32))
fanout = F.to_dgl_nd(fanout) fanout = F.to_dgl_nd(fanout)
if prob is None: prob_array = _prepare_edge_arrays(g, prob)
prob_array = nd.array([], ctx=nd.cpu())
elif isinstance(prob, nd.NDArray):
prob_array = prob
else:
if prob in g.edata:
prob_array = F.to_dgl_nd(g.edata[prob])
else:
prob_array = F.to_dgl_nd(F.tensor(prob, dtype=F.float32))
subgidx = _CAPI_DGLSampleNeighborsEType(g._graph, nodes, etypes, fanout, subgidx = _CAPI_DGLSampleNeighborsEType(
edge_dir, prob_array, replace, etype_sorted) g._graph, nodes, etype_offset, fanout, edge_dir, prob_array,
replace, etype_sorted)
induced_edges = subgidx.induced_edges induced_edges = subgidx.induced_edges
ret = DGLHeteroGraph(subgidx.graph, g.ntypes, g.etypes) ret = DGLHeteroGraph(subgidx.graph, g.ntypes, g.etypes)
...@@ -149,9 +180,10 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No ...@@ -149,9 +180,10 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No
DGLHeteroGraph.sample_etype_neighbors = utils.alias_func(sample_etype_neighbors) DGLHeteroGraph.sample_etype_neighbors = utils.alias_func(sample_etype_neighbors)
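A hedged sketch of the resulting call pattern (the homogenized graph and all inputs below are assumptions, not taken from the patch): the caller supplies the per-type edge-ID offsets, a per-type fanout tensor, and optionally one probability or mask tensor per edge type.

import torch

etype_offset = [0, 3, 4, 5]                    # type 0 -> eids 0..2, type 1 -> eid 3, type 2 -> eid 4
fanout = torch.LongTensor([2, 2, 2])           # one fanout per edge type
prob = [torch.rand(3), None, torch.rand(1)]    # one weight tensor per edge type; None means uniform
seeds = torch.tensor([0, 3])
# With a homogenized graph g_homo whose edge IDs follow etype_offset:
# frontier = dgl.sampling.sample_etype_neighbors(g_homo, seeds, etype_offset, fanout, prob=prob)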
def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False, def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None,
copy_ndata=True, copy_edata=True, _dist_training=False, replace=False, copy_ndata=True, copy_edata=True,
exclude_edges=None, output_device=None): _dist_training=False, exclude_edges=None,
output_device=None):
"""Sample neighboring edges of the given nodes and return the induced subgraph. """Sample neighboring edges of the given nodes and return the induced subgraph.
For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges
...@@ -181,7 +213,7 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False, ...@@ -181,7 +213,7 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False,
every edge type. every edge type.
If -1 is given for a single edge type, all the neighboring edges with that edge If -1 is given for a single edge type, all the neighboring edges with that edge
type will be selected. type and non-zero probability will be selected.
edge_dir : str, optional edge_dir : str, optional
Determines whether to sample inbound or outbound edges. Determines whether to sample inbound or outbound edges.
...@@ -191,9 +223,8 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False, ...@@ -191,9 +223,8 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False,
neighboring edge of a node. The feature must have only one element for each neighboring edge of a node. The feature must have only one element for each
edge. edge.
The features must be non-negative floats, and the sum of the features of The features must be non-negative floats or boolean. Otherwise, the result
inbound/outbound edges for every node must be positive (though they don't have will be undefined.
to sum up to one). Otherwise, the result will be undefined.
exclude_edges: tensor or dict exclude_edges: tensor or dict
Edge IDs to exclude during sampling neighbors for the seed nodes. Edge IDs to exclude during sampling neighbors for the seed nodes.
...@@ -290,20 +321,21 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False, ...@@ -290,20 +321,21 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False,
""" """
if F.device_type(g.device) == 'cpu' and not g.is_pinned(): if F.device_type(g.device) == 'cpu' and not g.is_pinned():
frontier = _sample_neighbors( frontier = _sample_neighbors(
g, nodes, fanout, edge_dir=edge_dir, prob=prob, replace=replace, g, nodes, fanout, edge_dir=edge_dir, prob=prob,
copy_ndata=copy_ndata, copy_edata=copy_edata, exclude_edges=exclude_edges) replace=replace, copy_ndata=copy_ndata, copy_edata=copy_edata,
exclude_edges=exclude_edges)
else: else:
frontier = _sample_neighbors( frontier = _sample_neighbors(
g, nodes, fanout, edge_dir=edge_dir, prob=prob, replace=replace, g, nodes, fanout, edge_dir=edge_dir, prob=prob,
copy_ndata=copy_ndata, copy_edata=copy_edata) replace=replace, copy_ndata=copy_ndata, copy_edata=copy_edata)
if exclude_edges is not None: if exclude_edges is not None:
eid_excluder = EidExcluder(exclude_edges) eid_excluder = EidExcluder(exclude_edges)
frontier = eid_excluder(frontier) frontier = eid_excluder(frontier)
return frontier if output_device is None else frontier.to(output_device) return frontier if output_device is None else frontier.to(output_device)
def _sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False, def _sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None,
copy_ndata=True, copy_edata=True, _dist_training=False, replace=False, copy_ndata=True, copy_edata=True,
exclude_edges=None): _dist_training=False, exclude_edges=None):
if not isinstance(nodes, dict): if not isinstance(nodes, dict):
if len(g.ntypes) > 1: if len(g.ntypes) > 1:
raise DGLError("Must specify node type when the graph is not homogeneous.") raise DGLError("Must specify node type when the graph is not homogeneous.")
...@@ -337,18 +369,7 @@ def _sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False, ...@@ -337,18 +369,7 @@ def _sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False,
fanout_array[g.get_etype_id(etype)] = value fanout_array[g.get_etype_id(etype)] = value
fanout_array = F.to_dgl_nd(F.tensor(fanout_array, dtype=F.int64)) fanout_array = F.to_dgl_nd(F.tensor(fanout_array, dtype=F.int64))
if isinstance(prob, list) and len(prob) > 0 and \ prob_arrays = _prepare_edge_arrays(g, prob)
isinstance(prob[0], nd.NDArray):
prob_arrays = prob
elif prob is None:
prob_arrays = [nd.array([], ctx=nd.cpu())] * len(g.etypes)
else:
prob_arrays = []
for etype in g.canonical_etypes:
if prob in g.edges[etype].data:
prob_arrays.append(F.to_dgl_nd(g.edges[etype].data[prob]))
else:
prob_arrays.append(nd.array([], ctx=nd.cpu()))
excluded_edges_all_t = [] excluded_edges_all_t = []
if exclude_edges is not None: if exclude_edges is not None:
...@@ -363,8 +384,9 @@ def _sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False, ...@@ -363,8 +384,9 @@ def _sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False,
else: else:
excluded_edges_all_t.append(nd.array([], ctx=ctx)) excluded_edges_all_t.append(nd.array([], ctx=ctx))
subgidx = _CAPI_DGLSampleNeighbors(g._graph, nodes_all_types, fanout_array, subgidx = _CAPI_DGLSampleNeighbors(
edge_dir, prob_arrays, excluded_edges_all_t, replace) g._graph, nodes_all_types, fanout_array, edge_dir, prob_arrays,
excluded_edges_all_t, replace)
induced_edges = subgidx.induced_edges induced_edges = subgidx.induced_edges
ret = DGLHeteroGraph(subgidx.graph, g.ntypes, g.etypes) ret = DGLHeteroGraph(subgidx.graph, g.ntypes, g.etypes)
...@@ -441,13 +463,12 @@ def sample_neighbors_biased(g, nodes, fanout, bias, edge_dir='in', ...@@ -441,13 +463,12 @@ def sample_neighbors_biased(g, nodes, fanout, bias, edge_dir='in',
fanout : int fanout : int
The number of edges to be sampled for each node on each edge type. The number of edges to be sampled for each node on each edge type.
If -1 is given, all the neighboring edges will be selected. If -1 is given, all the neighboring edges with non-zero probability will be selected.
bias : tensor or list bias : tensor or list
The (unnormalized) probabilities associated with each tag. Its length should be equal The (unnormalized) probabilities associated with each tag. Its length should be equal
to the number of tags. to the number of tags.
Entries of this array must be non-negative floats, and the sum of the entries must be Entries of this array must be non-negative floats. Otherwise, the result will be
positive (though they don't have to sum up to one). Otherwise, the result will be
undefined. undefined.
edge_dir : str, optional edge_dir : str, optional
Determines whether to sample inbound or outbound edges. Determines whether to sample inbound or outbound edges.
......
...@@ -558,7 +558,7 @@ COOMatrix CSRRowWiseSampling( ...@@ -558,7 +558,7 @@ COOMatrix CSRRowWiseSampling(
CHECK(!(prob_or_mask->dtype.bits == 8 && XPU == kDGLCUDA)) << CHECK(!(prob_or_mask->dtype.bits == 8 && XPU == kDGLCUDA)) <<
"GPU sampling with masks is currently not supported yet."; "GPU sampling with masks is currently not supported yet.";
ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH( ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(
prob_or_mask->dtype, FloatType, "prob_or_maskability or mask", { prob_or_mask->dtype, FloatType, "probability or mask", {
ret = impl::CSRRowWiseSampling<XPU, IdType, FloatType>( ret = impl::CSRRowWiseSampling<XPU, IdType, FloatType>(
mat, rows, num_samples, prob_or_mask, replace); mat, rows, num_samples, prob_or_mask, replace);
}); });
...@@ -568,18 +568,20 @@ COOMatrix CSRRowWiseSampling( ...@@ -568,18 +568,20 @@ COOMatrix CSRRowWiseSampling(
} }
COOMatrix CSRRowWisePerEtypeSampling( COOMatrix CSRRowWisePerEtypeSampling(
CSRMatrix mat, IdArray rows, IdArray etypes, CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_samples, FloatArray prob, bool replace, const std::vector<int64_t>& num_samples, const std::vector<NDArray>& prob_or_mask,
bool etype_sorted) { bool replace, bool rowwise_etype_sorted) {
COOMatrix ret; COOMatrix ret;
CHECK(prob_or_mask.size() > 0) << "probability or mask array is empty";
ATEN_CSR_SWITCH(mat, XPU, IdType, "CSRRowWisePerEtypeSampling", { ATEN_CSR_SWITCH(mat, XPU, IdType, "CSRRowWisePerEtypeSampling", {
if (IsNullArray(prob)) { if (std::all_of(prob_or_mask.begin(), prob_or_mask.end(), IsNullArray)) {
ret = impl::CSRRowWisePerEtypeSamplingUniform<XPU, IdType>( ret = impl::CSRRowWisePerEtypeSamplingUniform<XPU, IdType>(
mat, rows, etypes, num_samples, replace, etype_sorted); mat, rows, eid2etype_offset, num_samples, replace, rowwise_etype_sorted);
} else { } else {
ATEN_FLOAT_TYPE_SWITCH(prob->dtype, FloatType, "probability", { ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(
ret = impl::CSRRowWisePerEtypeSampling<XPU, IdType, FloatType>( prob_or_mask[0]->dtype, DType, "probability or mask", {
mat, rows, etypes, num_samples, prob, replace, etype_sorted); ret = impl::CSRRowWisePerEtypeSampling<XPU, IdType, DType>(
mat, rows, eid2etype_offset, num_samples, prob_or_mask, replace, rowwise_etype_sorted);
}); });
} }
}); });
...@@ -814,7 +816,7 @@ COOMatrix COORowWiseSampling( ...@@ -814,7 +816,7 @@ COOMatrix COORowWiseSampling(
ret = impl::COORowWiseSamplingUniform<XPU, IdType>(mat, rows, num_samples, replace); ret = impl::COORowWiseSamplingUniform<XPU, IdType>(mat, rows, num_samples, replace);
} else { } else {
ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH( ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(
prob_or_mask->dtype, DType, "prob_or_maskability or mask", { prob_or_mask->dtype, DType, "probability or mask", {
ret = impl::COORowWiseSampling<XPU, IdType, DType>( ret = impl::COORowWiseSampling<XPU, IdType, DType>(
mat, rows, num_samples, prob_or_mask, replace); mat, rows, num_samples, prob_or_mask, replace);
}); });
...@@ -824,18 +826,20 @@ COOMatrix COORowWiseSampling( ...@@ -824,18 +826,20 @@ COOMatrix COORowWiseSampling(
} }
COOMatrix COORowWisePerEtypeSampling( COOMatrix COORowWisePerEtypeSampling(
COOMatrix mat, IdArray rows, IdArray etypes, COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_samples, FloatArray prob, bool replace, const std::vector<int64_t>& num_samples, const std::vector<NDArray>& prob_or_mask,
bool etype_sorted) { bool replace) {
COOMatrix ret; COOMatrix ret;
CHECK(prob_or_mask.size() > 0) << "probability or mask array is empty";
ATEN_COO_SWITCH(mat, XPU, IdType, "COORowWisePerEtypeSampling", { ATEN_COO_SWITCH(mat, XPU, IdType, "COORowWisePerEtypeSampling", {
if (IsNullArray(prob)) { if (std::all_of(prob_or_mask.begin(), prob_or_mask.end(), IsNullArray)) {
ret = impl::COORowWisePerEtypeSamplingUniform<XPU, IdType>( ret = impl::COORowWisePerEtypeSamplingUniform<XPU, IdType>(
mat, rows, etypes, num_samples, replace, etype_sorted); mat, rows, eid2etype_offset, num_samples, replace);
} else { } else {
ATEN_FLOAT_TYPE_SWITCH(prob->dtype, FloatType, "probability", { ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(
ret = impl::COORowWisePerEtypeSampling<XPU, IdType, FloatType>( prob_or_mask[0]->dtype, DType, "probability or mask", {
mat, rows, etypes, num_samples, prob, replace, etype_sorted); ret = impl::COORowWisePerEtypeSampling<XPU, IdType, DType>(
mat, rows, eid2etype_offset, num_samples, prob_or_mask, replace);
}); });
} }
}); });
......
...@@ -46,6 +46,9 @@ DType IndexSelect(NDArray array, int64_t index); ...@@ -46,6 +46,9 @@ DType IndexSelect(NDArray array, int64_t index);
template <DGLDeviceType XPU, typename DType> template <DGLDeviceType XPU, typename DType>
IdArray NonZero(BoolArray bool_arr); IdArray NonZero(BoolArray bool_arr);
template <DGLDeviceType XPU, typename IdType>
IdArray NonZero(NDArray array);
template <DGLDeviceType XPU, typename DType> template <DGLDeviceType XPU, typename DType>
std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits); std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits);
...@@ -73,9 +76,6 @@ std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths); ...@@ -73,9 +76,6 @@ std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths);
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
IdArray CumSum(IdArray array, bool prepend_zero); IdArray CumSum(IdArray array, bool prepend_zero);
template <DGLDeviceType XPU, typename IdType>
IdArray NonZero(NDArray array);
// sparse arrays // sparse arrays
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
...@@ -165,11 +165,12 @@ COOMatrix CSRRowWiseSampling( ...@@ -165,11 +165,12 @@ COOMatrix CSRRowWiseSampling(
CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask, bool replace); CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask, bool replace);
// FloatType is the type of probability data. // FloatType is the type of probability data.
template <DGLDeviceType XPU, typename IdType, typename FloatType> template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix CSRRowWisePerEtypeSampling( COOMatrix CSRRowWisePerEtypeSampling(
CSRMatrix mat, IdArray rows, IdArray etypes, CSRMatrix mat, IdArray rows,
const std::vector<int64_t>& num_samples, FloatArray prob, bool replace, const std::vector<int64_t>& eid2etype_offset,
bool etype_sorted); const std::vector<int64_t>& num_samples,
const std::vector<NDArray>& prob_or_mask, bool replace, bool rowwise_etype_sorted);
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRRowWiseSamplingUniform( COOMatrix CSRRowWiseSamplingUniform(
...@@ -177,8 +178,9 @@ COOMatrix CSRRowWiseSamplingUniform( ...@@ -177,8 +178,9 @@ COOMatrix CSRRowWiseSamplingUniform(
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRRowWisePerEtypeSamplingUniform( COOMatrix CSRRowWisePerEtypeSamplingUniform(
CSRMatrix mat, IdArray rows, IdArray etypes, const std::vector<int64_t>& num_samples, CSRMatrix mat, IdArray rows,
bool replace, bool etype_sorted); const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_samples, bool replace, bool rowwise_etype_sorted);
// FloatType is the type of weight data. // FloatType is the type of weight data.
template <DGLDeviceType XPU, typename IdType, typename DType> template <DGLDeviceType XPU, typename IdType, typename DType>
...@@ -274,10 +276,12 @@ COOMatrix COORowWiseSampling( ...@@ -274,10 +276,12 @@ COOMatrix COORowWiseSampling(
COOMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask, bool replace); COOMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask, bool replace);
// FloatType is the type of probability data. // FloatType is the type of probability data.
template <DGLDeviceType XPU, typename IdType, typename FloatType> template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix COORowWisePerEtypeSampling( COOMatrix COORowWisePerEtypeSampling(
COOMatrix mat, IdArray rows, IdArray etypes, COOMatrix mat, IdArray rows,
const std::vector<int64_t>& num_samples, FloatArray prob, bool replace, bool etype_sorted); const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_samples,
const std::vector<NDArray>& prob_or_mask, bool replace);
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix COORowWiseSamplingUniform( COOMatrix COORowWiseSamplingUniform(
...@@ -285,8 +289,9 @@ COOMatrix COORowWiseSamplingUniform( ...@@ -285,8 +289,9 @@ COOMatrix COORowWiseSamplingUniform(
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix COORowWisePerEtypeSamplingUniform( COOMatrix COORowWisePerEtypeSamplingUniform(
COOMatrix mat, IdArray rows, IdArray etypes, const std::vector<int64_t>& num_samples, COOMatrix mat, IdArray rows,
bool replace, bool etype_sorted); const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_samples, bool replace);
// FloatType is the type of weight data. // FloatType is the type of weight data.
template <DGLDeviceType XPU, typename IdType, typename FloatType> template <DGLDeviceType XPU, typename IdType, typename FloatType>
......
...@@ -83,13 +83,14 @@ using NumPicksFn = std::function<IdxType( ...@@ -83,13 +83,14 @@ using NumPicksFn = std::function<IdxType(
// \param cur_et The edge type. // \param cur_et The edge type.
// \param et_len Length of the range. // \param et_len Length of the range.
// \param et_idx A map from local idx to column id. // \param et_idx A map from local idx to column id.
// \param data Pointer of the data indices. // \param et_eid Edge-type-specific id array.
// \param eid Pointer of the homogenized edge id array.
// \param out_idx Picked indices in [et_offset, et_offset + et_len). // \param out_idx Picked indices in [et_offset, et_offset + et_len).
template <typename IdxType> template <typename IdxType>
using RangePickFn = std::function<void( using EtypeRangePickFn = std::function<void(
IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len, IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
const std::vector<IdxType> &et_idx, const IdxType* data, const std::vector<IdxType>& et_idx, const std::vector<IdxType>& et_eid,
IdxType* out_idx)>; const IdxType* eid, IdxType* out_idx)>;
// Template for picking non-zero values row-wise. The implementation utilizes // Template for picking non-zero values row-wise. The implementation utilizes
// OpenMP parallelization on rows because each row performs computation independently. // OpenMP parallelization on rows because each row performs computation independently.
...@@ -208,20 +209,21 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows, ...@@ -208,20 +209,21 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
// Template for picking non-zero values row-wise. The implementation utilizes // Template for picking non-zero values row-wise. The implementation utilizes
// OpenMP parallelization on rows because each row performs computation independently. // OpenMP parallelization on rows because each row performs computation independently.
template <typename IdxType> template <typename IdxType, typename DType>
COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, IdArray etypes, COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows,
const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_picks, bool replace, const std::vector<int64_t>& num_picks, bool replace,
bool etype_sorted, RangePickFn<IdxType> pick_fn) { bool rowwise_etype_sorted, EtypeRangePickFn<IdxType> pick_fn,
const std::vector<NDArray>& prob_or_mask) {
using namespace aten; using namespace aten;
const IdxType* indptr = mat.indptr.Ptr<IdxType>(); const IdxType* indptr = mat.indptr.Ptr<IdxType>();
const IdxType* indices = mat.indices.Ptr<IdxType>(); const IdxType* indices = mat.indices.Ptr<IdxType>();
const IdxType* data = CSRHasData(mat)? mat.data.Ptr<IdxType>() : nullptr; const IdxType* eid = CSRHasData(mat)? mat.data.Ptr<IdxType>() : nullptr;
const IdxType* rows_data = rows.Ptr<IdxType>(); const IdxType* rows_data = rows.Ptr<IdxType>();
const int32_t* etype_data = etypes.Ptr<int32_t>();
const int64_t num_rows = rows->shape[0]; const int64_t num_rows = rows->shape[0];
const auto& ctx = mat.indptr->ctx; const auto& ctx = mat.indptr->ctx;
const int64_t num_etypes = num_picks.size(); const int64_t num_etypes = num_picks.size();
CHECK_EQ(etypes->dtype.bits / 8, sizeof(int32_t)) << "etypes must be int32"; const bool has_probs = (prob_or_mask.size() > 0);
std::vector<IdArray> picked_rows(rows->shape[0]); std::vector<IdArray> picked_rows(rows->shape[0]);
std::vector<IdArray> picked_cols(rows->shape[0]); std::vector<IdArray> picked_cols(rows->shape[0]);
std::vector<IdArray> picked_idxs(rows->shape[0]); std::vector<IdArray> picked_idxs(rows->shape[0]);
...@@ -260,13 +262,36 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, IdArray etypes, ...@@ -260,13 +262,36 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, IdArray etypes,
IdArray idx = Full(-1, len, sizeof(IdxType) * 8, ctx); IdArray idx = Full(-1, len, sizeof(IdxType) * 8, ctx);
IdxType* cdata = cols.Ptr<IdxType>(); IdxType* cdata = cols.Ptr<IdxType>();
IdxType* idata = idx.Ptr<IdxType>(); IdxType* idata = idx.Ptr<IdxType>();
int64_t k = 0;
for (int64_t j = 0; j < len; ++j) { for (int64_t j = 0; j < len; ++j) {
cdata[j] = indices[off + j]; const IdxType homogenized_eid = eid ? eid[off + j] : off + j;
idata[j] = data ? data[off + j] : off + j; auto it = std::upper_bound(
eid2etype_offset.begin(), eid2etype_offset.end(), homogenized_eid);
const IdxType heterogenized_etype = it - eid2etype_offset.begin() - 1;
const IdxType heterogenized_eid = \
homogenized_eid - eid2etype_offset[heterogenized_etype];
if (!has_probs || IsNullArray(prob_or_mask[heterogenized_etype])) {
// No probability array, select all
cdata[k] = indices[off + j];
idata[k] = homogenized_eid;
++k;
} else {
// Select the entries with non-zero probability
const NDArray& p = prob_or_mask[heterogenized_etype];
const DType* pdata = p.Ptr<DType>();
if (pdata[heterogenized_eid] > 0) {
cdata[k] = indices[off + j];
idata[k] = homogenized_eid;
++k;
}
}
} }
picked_rows[i] = rows;
picked_cols[i] = cols; picked_rows[i] = rows.CreateView({k}, rows->dtype);
picked_idxs[i] = idx; picked_cols[i] = cols.CreateView({k}, cols->dtype);
picked_idxs[i] = idx.CreateView({k}, idx->dtype);
} else { } else {
// need to do per edge type sample // need to do per edge type sample
std::vector<IdxType> rows; std::vector<IdxType> rows;
...@@ -275,15 +300,22 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, IdArray etypes, ...@@ -275,15 +300,22 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, IdArray etypes,
std::vector<IdxType> et(len); std::vector<IdxType> et(len);
std::vector<IdxType> et_idx(len); std::vector<IdxType> et_idx(len);
std::vector<IdxType> et_eid(len);
std::iota(et_idx.begin(), et_idx.end(), 0); std::iota(et_idx.begin(), et_idx.end(), 0);
for (int64_t j = 0; j < len; ++j) { for (int64_t j = 0; j < len; ++j) {
et[j] = data ? etype_data[data[off+j]] : etype_data[off+j]; const IdxType homogenized_eid = eid ? eid[off + j] : off + j;
auto it = std::upper_bound(
eid2etype_offset.begin(), eid2etype_offset.end(), homogenized_eid);
const IdxType heterogenized_etype = it - eid2etype_offset.begin() - 1;
const IdxType heterogenized_eid = \
homogenized_eid - eid2etype_offset[heterogenized_etype];
et[j] = heterogenized_etype;
et_eid[j] = heterogenized_eid;
} }
if (!etype_sorted) // the edge type is sorted, not need to sort it if (!rowwise_etype_sorted) // no need to sort if the edge types are already sorted within each row
std::sort(et_idx.begin(), et_idx.end(), std::sort(et_idx.begin(), et_idx.end(),
[&et](IdxType i1, IdxType i2) {return et[i1] < et[i2];}); [&et](IdxType i1, IdxType i2) {return et[i1] < et[i2];});
CHECK(et[et_idx[len - 1]] < num_etypes) << CHECK_LT(et[et_idx[len - 1]], num_etypes) << "etype values exceed the number of fanouts";
"etype values exceed the number of fanouts";
IdxType cur_et = et[et_idx[0]]; IdxType cur_et = et[et_idx[0]];
int64_t et_offset = 0; int64_t et_offset = 0;
...@@ -291,7 +323,7 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, IdArray etypes, ...@@ -291,7 +323,7 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, IdArray etypes,
for (int64_t j = 0; j < len; ++j) { for (int64_t j = 0; j < len; ++j) {
CHECK((j + 1 == len) || (et[et_idx[j]] <= et[et_idx[j + 1]])) CHECK((j + 1 == len) || (et[et_idx[j]] <= et[et_idx[j + 1]]))
<< "Edge type is not sorted. Please sort in advance or specify " << "Edge type is not sorted. Please sort in advance or specify "
"'etype_sorted' as false."; "'rowwise_etype_sorted' as false.";
if ((j + 1 == len) || cur_et != et[et_idx[j + 1]]) { if ((j + 1 == len) || cur_et != et[et_idx[j + 1]]) {
// 1 end of the current etype // 1 end of the current etype
// 2 end of the row // 2 end of the row
...@@ -300,29 +332,49 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, IdArray etypes, ...@@ -300,29 +332,49 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, IdArray etypes,
(et_len <= num_picks[cur_et] && !replace)) { (et_len <= num_picks[cur_et] && !replace)) {
// fast path, select all // fast path, select all
for (int64_t k = 0; k < et_len; ++k) { for (int64_t k = 0; k < et_len; ++k) {
rows.push_back(rid); const IdxType eid_offset = off + et_idx[et_offset + k];
cols.push_back(indices[off+et_idx[et_offset+k]]); const IdxType homogenized_eid = eid ? eid[eid_offset] : eid_offset;
if (data) auto it = std::upper_bound(
idx.push_back(data[off+et_idx[et_offset+k]]); eid2etype_offset.begin(), eid2etype_offset.end(), homogenized_eid);
else const IdxType heterogenized_etype = it - eid2etype_offset.begin() - 1;
idx.push_back(off+et_idx[et_offset+k]); const IdxType heterogenized_eid = \
homogenized_eid - eid2etype_offset[heterogenized_etype];
if (!has_probs || IsNullArray(prob_or_mask[heterogenized_etype])) {
// No probability, select all
rows.push_back(rid);
cols.push_back(indices[eid_offset]);
idx.push_back(homogenized_eid);
} else {
// Select the entries with non-zero probability
const NDArray& p = prob_or_mask[heterogenized_etype];
const DType* pdata = p.Ptr<DType>();
if (pdata[heterogenized_eid] > 0) {
rows.push_back(rid);
cols.push_back(indices[eid_offset]);
idx.push_back(homogenized_eid);
}
}
} }
} else { } else {
IdArray picked_idx = Full(-1, num_picks[cur_et], sizeof(IdxType) * 8, ctx); IdArray picked_idx = Full(-1, num_picks[cur_et], sizeof(IdxType) * 8, ctx);
IdxType* picked_idata = static_cast<IdxType*>(picked_idx->data); IdxType* picked_idata = picked_idx.Ptr<IdxType>();
// need call random pick // need call random pick
pick_fn(off, et_offset, cur_et, pick_fn(off, et_offset, cur_et,
et_len, et_idx, et_len, et_idx, et_eid,
data, picked_idata); eid, picked_idata);
for (int64_t k = 0; k < num_picks[cur_et]; ++k) { for (int64_t k = 0; k < num_picks[cur_et]; ++k) {
const IdxType picked = picked_idata[k]; const IdxType picked = picked_idata[k];
if (picked == -1)
continue;
rows.push_back(rid); rows.push_back(rid);
cols.push_back(indices[off+et_idx[et_offset+picked]]); cols.push_back(indices[off+et_idx[et_offset+picked]]);
if (data) if (eid) {
idx.push_back(data[off+et_idx[et_offset+picked]]); idx.push_back(eid[off+et_idx[et_offset+picked]]);
else } else {
idx.push_back(off+et_idx[et_offset+picked]); idx.push_back(off+et_idx[et_offset+picked]);
}
} }
} }
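The lookup repeated in the hunk above maps a homogenized edge ID back to its edge type and its per-type edge ID via std::upper_bound on eid2etype_offset; the per-type ID is what indexes prob_or_mask[etype]. A standalone restatement of that lookup, under the same assumption that eid2etype_offset is sorted with num_etypes + 1 entries; the function name is illustrative, not part of the patch.

    #include <algorithm>
    #include <cstdint>
    #include <utility>
    #include <vector>

    // Sketch only: map a homogenized edge ID to (edge type, per-type edge ID).
    std::pair<int64_t, int64_t> MapToPerType(
        int64_t homogenized_eid, const std::vector<int64_t>& eid2etype_offset) {
      // First offset strictly greater than the ID marks the next edge type.
      auto it = std::upper_bound(
          eid2etype_offset.begin(), eid2etype_offset.end(), homogenized_eid);
      const int64_t etype = it - eid2etype_offset.begin() - 1;
      const int64_t per_type_eid = homogenized_eid - eid2etype_offset[etype];
      return {etype, per_type_eid};  // per_type_eid indexes prob_or_mask[etype]
    }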
...@@ -375,15 +427,17 @@ COOMatrix COORowWisePick(COOMatrix mat, IdArray rows, ...@@ -375,15 +427,17 @@ COOMatrix COORowWisePick(COOMatrix mat, IdArray rows,
// Template for picking non-zero values row-wise. The implementation first slices // Template for picking non-zero values row-wise. The implementation first slices
// out the corresponding rows and then converts it to CSR format. It then performs // out the corresponding rows and then converts it to CSR format. It then performs
// row-wise pick on the CSR matrix and rectifies the returned results. // row-wise pick on the CSR matrix and rectifies the returned results.
template <typename IdxType> template <typename IdxType, typename DType>
COOMatrix COORowWisePerEtypePick(COOMatrix mat, IdArray rows, IdArray etypes, COOMatrix COORowWisePerEtypePick(
const std::vector<int64_t>& num_picks, bool replace, COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
bool etype_sorted, RangePickFn<IdxType> pick_fn) { const std::vector<int64_t>& num_picks, bool replace,
EtypeRangePickFn<IdxType> pick_fn,
const std::vector<NDArray>& prob_or_mask) {
using namespace aten; using namespace aten;
const auto& csr = COOToCSR(COOSliceRows(mat, rows)); const auto& csr = COOToCSR(COOSliceRows(mat, rows));
const IdArray new_rows = Range(0, rows->shape[0], rows->dtype.bits, rows->ctx); const IdArray new_rows = Range(0, rows->shape[0], rows->dtype.bits, rows->ctx);
const auto& picked = CSRRowWisePerEtypePick<IdxType>( const auto& picked = CSRRowWisePerEtypePick<IdxType, DType>(
csr, new_rows, etypes, num_picks, replace, etype_sorted, pick_fn); csr, new_rows, eid2etype_offset, num_picks, replace, false, pick_fn, prob_or_mask);
return COOMatrix(mat.num_rows, mat.num_cols, return COOMatrix(mat.num_rows, mat.num_cols,
IndexSelect(rows, picked.row), // map the row index to the correct one IndexSelect(rows, picked.row), // map the row index to the correct one
picked.col, picked.col,
......
...@@ -37,7 +37,8 @@ inline NumPicksFn<IdxType> GetSamplingNumPicksFn( ...@@ -37,7 +37,8 @@ inline NumPicksFn<IdxType> GetSamplingNumPicksFn(
const DType* prob_or_mask_data = prob_or_mask.Ptr<DType>(); const DType* prob_or_mask_data = prob_or_mask.Ptr<DType>();
IdxType nnz = 0; IdxType nnz = 0;
for (IdxType i = off; i < off + len; ++i) { for (IdxType i = off; i < off + len; ++i) {
if (prob_or_mask_data[i] > 0) { const IdxType eid = data ? data[i] : i;
if (prob_or_mask_data[eid] > 0) {
++nnz; ++nnz;
} }
} }
...@@ -69,20 +70,21 @@ inline PickFn<IdxType> GetSamplingPickFn( ...@@ -69,20 +70,21 @@ inline PickFn<IdxType> GetSamplingPickFn(
} }
template <typename IdxType, typename FloatType> template <typename IdxType, typename FloatType>
inline RangePickFn<IdxType> GetSamplingRangePickFn( inline EtypeRangePickFn<IdxType> GetSamplingRangePickFn(
const std::vector<int64_t>& num_samples, FloatArray prob, bool replace) { const std::vector<int64_t>& num_samples,
RangePickFn<IdxType> pick_fn = [prob, num_samples, replace] const std::vector<FloatArray>& prob, bool replace) {
EtypeRangePickFn<IdxType> pick_fn = [prob, num_samples, replace]
(IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len, (IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
const std::vector<IdxType> &et_idx, const std::vector<IdxType> &et_idx,
const IdxType* data, IdxType* out_idx) { const std::vector<IdxType> &et_eid,
const FloatType* p_data = static_cast<FloatType*>(prob->data); const IdxType* eid, IdxType* out_idx) {
FloatArray probs = FloatArray::Empty({et_len}, prob->dtype, prob->ctx); const FloatArray& p = prob[cur_et];
FloatType* probs_data = static_cast<FloatType*>(probs->data); const FloatType* p_data = IsNullArray(p) ? nullptr : p.Ptr<FloatType>();
FloatArray probs = FloatArray::Empty({et_len}, p->dtype, p->ctx);
FloatType* probs_data = probs.Ptr<FloatType>();
for (int64_t j = 0; j < et_len; ++j) { for (int64_t j = 0; j < et_len; ++j) {
if (data) const IdxType cur_eid = et_eid[et_idx[et_offset + j]];
probs_data[j] = p_data[data[off+et_idx[et_offset+j]]]; probs_data[j] = p_data ? p_data[cur_eid] : static_cast<FloatType>(1.);
else
probs_data[j] = p_data[off+et_idx[et_offset+j]];
} }
RandomEngine::ThreadLocal()->Choice<IdxType, FloatType>( RandomEngine::ThreadLocal()->Choice<IdxType, FloatType>(
...@@ -124,11 +126,12 @@ inline PickFn<IdxType> GetSamplingUniformPickFn( ...@@ -124,11 +126,12 @@ inline PickFn<IdxType> GetSamplingUniformPickFn(
} }
template <typename IdxType> template <typename IdxType>
inline RangePickFn<IdxType> GetSamplingUniformRangePickFn( inline EtypeRangePickFn<IdxType> GetSamplingUniformRangePickFn(
const std::vector<int64_t>& num_samples, bool replace) { const std::vector<int64_t>& num_samples, bool replace) {
RangePickFn<IdxType> pick_fn = [num_samples, replace] EtypeRangePickFn<IdxType> pick_fn = [num_samples, replace]
(IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len, (IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
const std::vector<IdxType> &et_idx, const std::vector<IdxType> &et_idx,
const std::vector<IdxType> &et_eid,
const IdxType* data, IdxType* out_idx) { const IdxType* data, IdxType* out_idx) {
RandomEngine::ThreadLocal()->UniformChoice<IdxType>( RandomEngine::ThreadLocal()->UniformChoice<IdxType>(
num_samples[cur_et], et_len, out_idx, replace); num_samples[cur_et], et_len, out_idx, replace);
...@@ -213,23 +216,45 @@ template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, uint8_t>( ...@@ -213,23 +216,45 @@ template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, uint8_t>(
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, uint8_t>( template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, uint8_t>(
CSRMatrix, IdArray, int64_t, NDArray, bool); CSRMatrix, IdArray, int64_t, NDArray, bool);
template <DGLDeviceType XPU, typename IdxType, typename FloatType> template <DGLDeviceType XPU, typename IdxType, typename DType>
COOMatrix CSRRowWisePerEtypeSampling(CSRMatrix mat, IdArray rows, IdArray etypes, COOMatrix CSRRowWisePerEtypeSampling(
const std::vector<int64_t>& num_samples, CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
FloatArray prob, bool replace, bool etype_sorted) { const std::vector<int64_t>& num_samples, const std::vector<NDArray>& prob_or_mask,
CHECK(prob.defined()); bool replace, bool rowwise_etype_sorted) {
auto pick_fn = GetSamplingRangePickFn<IdxType, FloatType>(num_samples, prob, replace); CHECK(prob_or_mask.size() == num_samples.size()) <<
return CSRRowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn); "the number of probability tensors does not match the number of edge types.";
for (auto& p : prob_or_mask)
CHECK(p.defined());
auto pick_fn = GetSamplingRangePickFn<IdxType, DType>(num_samples, prob_or_mask, replace);
return CSRRowWisePerEtypePick<IdxType, DType>(
mat, rows, eid2etype_offset, num_samples, replace, rowwise_etype_sorted, pick_fn,
prob_or_mask);
} }
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, float>( template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, float>(
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool); CSRMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, float>( template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, float>(
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool); CSRMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, double>( template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, double>(
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool); CSRMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, double>( template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, double>(
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool); CSRMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, int8_t>(
CSRMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, int8_t>(
CSRMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, uint8_t>(
CSRMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, uint8_t>(
CSRMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool, bool);
template <DGLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat, IdArray rows, COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat, IdArray rows,
...@@ -247,17 +272,20 @@ template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int64_t>( ...@@ -247,17 +272,20 @@ template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int64_t>(
CSRMatrix, IdArray, int64_t, bool); CSRMatrix, IdArray, int64_t, bool);
template <DGLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
COOMatrix CSRRowWisePerEtypeSamplingUniform(CSRMatrix mat, IdArray rows, IdArray etypes, COOMatrix CSRRowWisePerEtypeSamplingUniform(
const std::vector<int64_t>& num_samples, CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
bool replace, bool etype_sorted) { const std::vector<int64_t>& num_samples, bool replace, bool rowwise_etype_sorted) {
auto pick_fn = GetSamplingUniformRangePickFn<IdxType>(num_samples, replace); auto pick_fn = GetSamplingUniformRangePickFn<IdxType>(num_samples, replace);
return CSRRowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn); return CSRRowWisePerEtypePick<IdxType, float>(
mat, rows, eid2etype_offset, num_samples, replace, rowwise_etype_sorted, pick_fn, {});
} }
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDGLCPU, int32_t>( template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDGLCPU, int32_t>(
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool); CSRMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&, bool,
bool);
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDGLCPU, int64_t>( template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDGLCPU, int64_t>(
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool); CSRMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&, bool,
bool);
template <DGLDeviceType XPU, typename IdxType, typename FloatType> template <DGLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix CSRRowWiseSamplingBiased( COOMatrix CSRRowWiseSamplingBiased(
...@@ -322,23 +350,44 @@ template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, uint8_t>( ...@@ -322,23 +350,44 @@ template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, uint8_t>(
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, uint8_t>( template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, uint8_t>(
COOMatrix, IdArray, int64_t, NDArray, bool); COOMatrix, IdArray, int64_t, NDArray, bool);
template <DGLDeviceType XPU, typename IdxType, typename FloatType> template <DGLDeviceType XPU, typename IdxType, typename DType>
COOMatrix COORowWisePerEtypeSampling(COOMatrix mat, IdArray rows, IdArray etypes, COOMatrix COORowWisePerEtypeSampling(
const std::vector<int64_t>& num_samples, COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
FloatArray prob, bool replace, bool etype_sorted) { const std::vector<int64_t>& num_samples, const std::vector<NDArray>& prob_or_mask,
CHECK(prob.defined()); bool replace) {
auto pick_fn = GetSamplingRangePickFn<IdxType, FloatType>(num_samples, prob, replace); CHECK(prob_or_mask.size() == num_samples.size()) <<
return COORowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn); "the number of probability tensors does not match the number of edge types.";
for (auto& p : prob_or_mask)
CHECK(p.defined());
auto pick_fn = GetSamplingRangePickFn<IdxType, DType>(num_samples, prob_or_mask, replace);
return COORowWisePerEtypePick<IdxType, DType>(
mat, rows, eid2etype_offset, num_samples, replace, pick_fn, prob_or_mask);
} }
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, float>( template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, float>(
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool); COOMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool);
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, float>( template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, float>(
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool); COOMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool);
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, double>( template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, double>(
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool); COOMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool);
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, double>( template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, double>(
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool); COOMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool);
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, int8_t>(
COOMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool);
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, int8_t>(
COOMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool);
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, uint8_t>(
COOMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool);
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, uint8_t>(
COOMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool);
template <DGLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
COOMatrix COORowWiseSamplingUniform(COOMatrix mat, IdArray rows, COOMatrix COORowWiseSamplingUniform(COOMatrix mat, IdArray rows,
...@@ -356,17 +405,18 @@ template COOMatrix COORowWiseSamplingUniform<kDGLCPU, int64_t>( ...@@ -356,17 +405,18 @@ template COOMatrix COORowWiseSamplingUniform<kDGLCPU, int64_t>(
COOMatrix, IdArray, int64_t, bool); COOMatrix, IdArray, int64_t, bool);
template <DGLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
COOMatrix COORowWisePerEtypeSamplingUniform(COOMatrix mat, IdArray rows, IdArray etypes, COOMatrix COORowWisePerEtypeSamplingUniform(
const std::vector<int64_t>& num_samples, COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
bool replace, bool etype_sorted) { const std::vector<int64_t>& num_samples, bool replace) {
auto pick_fn = GetSamplingUniformRangePickFn<IdxType>(num_samples, replace); auto pick_fn = GetSamplingUniformRangePickFn<IdxType>(num_samples, replace);
return COORowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn); return COORowWisePerEtypePick<IdxType, float>(
mat, rows, eid2etype_offset, num_samples, replace, pick_fn, {});
} }
template COOMatrix COORowWisePerEtypeSamplingUniform<kDGLCPU, int32_t>( template COOMatrix COORowWisePerEtypeSamplingUniform<kDGLCPU, int32_t>(
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool); COOMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&, bool);
template COOMatrix COORowWisePerEtypeSamplingUniform<kDGLCPU, int64_t>( template COOMatrix COORowWisePerEtypeSamplingUniform<kDGLCPU, int64_t>(
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool); COOMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&, bool);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -77,7 +77,7 @@ HeteroSubgraph SampleNeighbors( ...@@ -77,7 +77,7 @@ HeteroSubgraph SampleNeighbors(
CHECK_EQ(fanouts.size(), hg->NumEdgeTypes()) CHECK_EQ(fanouts.size(), hg->NumEdgeTypes())
<< "Number of fanout values must match the number of edge types."; << "Number of fanout values must match the number of edge types.";
CHECK_EQ(prob_or_mask.size(), hg->NumEdgeTypes()) CHECK_EQ(prob_or_mask.size(), hg->NumEdgeTypes())
<< "Number of prob_or_maskability tensors must match the number of edge types."; << "Number of probability tensors must match the number of edge types.";
DGLContext ctx = aten::GetContextOf(nodes); DGLContext ctx = aten::GetContextOf(nodes);
...@@ -149,12 +149,12 @@ HeteroSubgraph SampleNeighbors( ...@@ -149,12 +149,12 @@ HeteroSubgraph SampleNeighbors(
HeteroSubgraph SampleNeighborsEType( HeteroSubgraph SampleNeighborsEType(
const HeteroGraphPtr hg, const HeteroGraphPtr hg,
const IdArray nodes, const IdArray nodes,
const IdArray etypes, const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& fanouts, const std::vector<int64_t>& fanouts,
EdgeDir dir, EdgeDir dir,
const IdArray prob, const std::vector<FloatArray>& prob,
bool replace, bool replace,
bool etype_sorted) { bool rowwise_etype_sorted) {
CHECK_EQ(1, hg->NumVertexTypes()) CHECK_EQ(1, hg->NumVertexTypes())
<< "SampleNeighborsEType only work with homogeneous graph"; << "SampleNeighborsEType only work with homogeneous graph";
...@@ -183,43 +183,34 @@ HeteroSubgraph SampleNeighborsEType( ...@@ -183,43 +183,34 @@ HeteroSubgraph SampleNeighborsEType(
hg->NumVertices(dst_vtype), hg->NumVertices(dst_vtype),
hg->DataType(), hg->Context()); hg->DataType(), hg->Context());
induced_edges[etype] = aten::NullArray(); induced_edges[etype] = aten::NullArray();
} else if (same_fanout && fanout_value == -1) {
const auto &earr = (dir == EdgeDir::kOut) ?
hg->OutEdges(etype, nodes) :
hg->InEdges(etype, nodes);
subrels[etype] = UnitGraph::CreateFromCOO(
1,
hg->NumVertices(src_vtype),
hg->NumVertices(dst_vtype),
earr.src,
earr.dst);
induced_edges[etype] = earr.id;
} else { } else {
COOMatrix sampled_coo;
// sample from graph // sample from graph
// the edge type is stored in etypes // the per-etype edge ID ranges are given by eid2etype_offset
auto req_fmt = (dir == EdgeDir::kOut)? CSR_CODE : CSC_CODE; auto req_fmt = (dir == EdgeDir::kOut)? CSR_CODE : CSC_CODE;
auto avail_fmt = hg->SelectFormat(etype, req_fmt); auto avail_fmt = hg->SelectFormat(etype, req_fmt);
COOMatrix sampled_coo;
switch (avail_fmt) { switch (avail_fmt) {
case SparseFormat::kCOO: case SparseFormat::kCOO:
if (dir == EdgeDir::kIn) { if (dir == EdgeDir::kIn) {
sampled_coo = aten::COOTranspose(aten::COORowWisePerEtypeSampling( sampled_coo = aten::COOTranspose(aten::COORowWisePerEtypeSampling(
aten::COOTranspose(hg->GetCOOMatrix(etype)), aten::COOTranspose(hg->GetCOOMatrix(etype)),
nodes, etypes, fanouts, prob, replace)); nodes, eid2etype_offset, fanouts, prob, replace));
} else { } else {
sampled_coo = aten::COORowWisePerEtypeSampling( sampled_coo = aten::COORowWisePerEtypeSampling(
hg->GetCOOMatrix(etype), nodes, etypes, fanouts, prob, replace, etype_sorted); hg->GetCOOMatrix(etype), nodes, eid2etype_offset, fanouts, prob, replace);
} }
break; break;
case SparseFormat::kCSR: case SparseFormat::kCSR:
CHECK(dir == EdgeDir::kOut) << "Cannot sample out edges on CSC matrix."; CHECK(dir == EdgeDir::kOut) << "Cannot sample out edges on CSC matrix.";
sampled_coo = aten::CSRRowWisePerEtypeSampling( sampled_coo = aten::CSRRowWisePerEtypeSampling(
hg->GetCSRMatrix(etype), nodes, etypes, fanouts, prob, replace, etype_sorted); hg->GetCSRMatrix(etype), nodes, eid2etype_offset,
fanouts, prob, replace, rowwise_etype_sorted);
break; break;
case SparseFormat::kCSC: case SparseFormat::kCSC:
CHECK(dir == EdgeDir::kIn) << "Cannot sample in edges on CSR matrix."; CHECK(dir == EdgeDir::kIn) << "Cannot sample in edges on CSR matrix.";
sampled_coo = aten::CSRRowWisePerEtypeSampling( sampled_coo = aten::CSRRowWisePerEtypeSampling(
hg->GetCSCMatrix(etype), nodes, etypes, fanouts, prob, replace, etype_sorted); hg->GetCSCMatrix(etype), nodes, eid2etype_offset,
fanouts, prob, replace, rowwise_etype_sorted);
sampled_coo = aten::COOTranspose(sampled_coo); sampled_coo = aten::COOTranspose(sampled_coo);
break; break;
default: default:
...@@ -386,12 +377,12 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsEType") ...@@ -386,12 +377,12 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsEType")
.set_body([] (DGLArgs args, DGLRetValue *rv) { .set_body([] (DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
IdArray nodes = args[1]; IdArray nodes = args[1];
IdArray etypes = args[2]; const std::vector<int64_t>& eid2etype_offset = ListValueToVector<int64_t>(args[2]);
IdArray fanout = args[3]; IdArray fanout = args[3];
const std::string dir_str = args[4]; const std::string dir_str = args[4];
IdArray prob = args[5]; const auto& prob = ListValueToVector<FloatArray>(args[5]);
const bool replace = args[6]; const bool replace = args[6];
const bool etype_sorted = args[7]; const bool rowwise_etype_sorted = args[7];
CHECK(dir_str == "in" || dir_str == "out") CHECK(dir_str == "in" || dir_str == "out")
<< "Invalid edge direction. Must be \"in\" or \"out\"."; << "Invalid edge direction. Must be \"in\" or \"out\".";
...@@ -401,8 +392,7 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsEType") ...@@ -401,8 +392,7 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsEType")
std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph); std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph);
*subg = sampling::SampleNeighborsEType( *subg = sampling::SampleNeighborsEType(
hg.sptr(), nodes, etypes, fanout_vec, dir, prob, replace, etype_sorted); hg.sptr(), nodes, eid2etype_offset, fanout_vec, dir, prob, replace, rowwise_etype_sorted);
*rv = HeteroSubgraphRef(subg); *rv = HeteroSubgraphRef(subg);
}); });
......
...@@ -742,6 +742,11 @@ def create_etype_test_graph(num_nodes, num_edges_per_node, rare_cnt): ...@@ -742,6 +742,11 @@ def create_etype_test_graph(num_nodes, num_edges_per_node, rare_cnt):
("v2", "e_minor", "u") : (minor_src, minor_dst), ("v2", "e_minor", "u") : (minor_src, minor_dst),
("v2", "most_zero", "u") : (most_zero_src, most_zero_dst), ("v2", "most_zero", "u") : (most_zero_src, most_zero_dst),
("u", "e_minor_rev", "v2") : (minor_dst, minor_src)}) ("u", "e_minor_rev", "v2") : (minor_dst, minor_src)})
for etype in g.etypes:
prob = np.random.rand(g.num_edges(etype))
prob[prob > 0.2] = 0
g.edges[etype].data['p'] = F.zerocopy_from_numpy(prob)
g.edges[etype].data['mask'] = F.zerocopy_from_numpy(prob != 0)
return g return g
...@@ -835,6 +840,7 @@ def test_sample_neighbors_biased_bipartite(): ...@@ -835,6 +840,7 @@ def test_sample_neighbors_biased_bipartite():
check_num(subg.edges()[1], tag) check_num(subg.edges()[1], tag)
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented") @unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented")
@unittest.skipIf(F.backend_name == 'mxnet', reason='MXNet has problems converting bool arrays')
@pytest.mark.parametrize('format_', ['coo', 'csr', 'csc']) @pytest.mark.parametrize('format_', ['coo', 'csr', 'csc'])
@pytest.mark.parametrize('direction', ['in', 'out']) @pytest.mark.parametrize('direction', ['in', 'out'])
@pytest.mark.parametrize('replace', [False, True]) @pytest.mark.parametrize('replace', [False, True])
...@@ -842,38 +848,59 @@ def test_sample_neighbors_etype_homogeneous(format_, direction, replace): ...@@ -842,38 +848,59 @@ def test_sample_neighbors_etype_homogeneous(format_, direction, replace):
num_nodes = 100 num_nodes = 100
rare_cnt = 4 rare_cnt = 4
g = create_etype_test_graph(100, 30, rare_cnt) g = create_etype_test_graph(100, 30, rare_cnt)
h_g = dgl.to_homogeneous(g) h_g = dgl.to_homogeneous(g, edata=['p', 'mask'])
h_g_etype = F.asnumpy(h_g.edata[dgl.ETYPE])
h_g_offset = np.cumsum(np.insert(np.bincount(h_g_etype), 0, 0)).tolist()
sg = g.edge_subgraph(g.edata['mask'], relabel_nodes=False)
h_sg = h_g.edge_subgraph(h_g.edata['mask'], relabel_nodes=False)
h_sg_etype = F.asnumpy(h_sg.edata[dgl.ETYPE])
h_sg_offset = np.cumsum(np.insert(np.bincount(h_sg_etype), 0, 0)).tolist()
seed_ntype = g.get_ntype_id("u") seed_ntype = g.get_ntype_id("u")
seeds = F.nonzero_1d(h_g.ndata[dgl.NTYPE] == seed_ntype) seeds = F.nonzero_1d(h_g.ndata[dgl.NTYPE] == seed_ntype)
fanouts = F.tensor([6, 5, 4, 3, 2], dtype=F.int64) fanouts = F.tensor([6, 5, 4, 3, 2], dtype=F.int64)
def check_num(h_g, all_src, all_dst, subg, replace, fanouts, direction): def check_num(h_g, all_src, all_dst, subg, replace, fanouts, direction):
src, dst = subg.edges() src, dst = subg.edges()
num_etypes = F.asnumpy(h_g.edata[dgl.ETYPE]).max() all_etype_array = F.asnumpy(h_g.edata[dgl.ETYPE])
num_etypes = all_etype_array.max() + 1
etype_array = F.asnumpy(subg.edata[dgl.ETYPE]) etype_array = F.asnumpy(subg.edata[dgl.ETYPE])
src = F.asnumpy(src) src = F.asnumpy(src)
dst = F.asnumpy(dst) dst = F.asnumpy(dst)
fanouts = F.asnumpy(fanouts) fanouts = F.asnumpy(fanouts)
all_etype_array = F.asnumpy(h_g.edata[dgl.ETYPE])
all_src = F.asnumpy(all_src) all_src = F.asnumpy(all_src)
all_dst = F.asnumpy(all_dst) all_dst = F.asnumpy(all_dst)
src_per_etype = [] src_per_etype = []
dst_per_etype = [] dst_per_etype = []
all_src_per_etype = []
all_dst_per_etype = []
for etype in range(num_etypes): for etype in range(num_etypes):
src_per_etype.append(src[etype_array == etype]) src_per_etype.append(src[etype_array == etype])
dst_per_etype.append(dst[etype_array == etype]) dst_per_etype.append(dst[etype_array == etype])
all_src_per_etype.append(all_src[all_etype_array == etype])
all_dst_per_etype.append(all_dst[all_etype_array == etype])
if replace: if replace:
if direction == 'in': if direction == 'in':
in_degree_per_etype = [np.bincount(d) for d in dst_per_etype] in_degree_per_etype = [np.bincount(d) for d in dst_per_etype]
for in_degree, fanout in zip(in_degree_per_etype, fanouts): for etype in range(len(fanouts)):
assert np.all(in_degree == fanout) in_degree = in_degree_per_etype[etype]
fanout = fanouts[etype]
ans = np.zeros_like(in_degree)
if len(in_degree) > 0:
ans[all_dst_per_etype[etype]] = fanout
assert np.all(in_degree == ans)
else: else:
out_degree_per_etype = [np.bincount(s) for s in src_per_etype] out_degree_per_etype = [np.bincount(s) for s in src_per_etype]
for out_degree, fanout in zip(out_degree_per_etype, fanouts): for etype in range(len(fanouts)):
assert np.all(out_degree == fanout) out_degree = out_degree_per_etype[etype]
fanout = fanouts[etype]
ans = np.zeros_like(out_degree)
if len(out_degree) > 0:
ans[all_src_per_etype[etype]] = fanout
assert np.all(out_degree == ans)
else: else:
if direction == 'in': if direction == 'in':
for v in set(dst): for v in set(dst):
...@@ -897,16 +924,31 @@ def test_sample_neighbors_etype_homogeneous(format_, direction, replace): ...@@ -897,16 +924,31 @@ def test_sample_neighbors_etype_homogeneous(format_, direction, replace):
assert (len(v_etype) == fanouts[etype]) or (v_etype == all_v_etype) assert (len(v_etype) == fanouts[etype]) or (v_etype == all_v_etype)
all_src, all_dst = h_g.edges() all_src, all_dst = h_g.edges()
all_sub_src, all_sub_dst = h_sg.edges()
h_g = h_g.formats(format_) h_g = h_g.formats(format_)
if (direction, format_) in [('in', 'csr'), ('out', 'csc')]: if (direction, format_) in [('in', 'csr'), ('out', 'csc')]:
h_g = h_g.formats(['csc', 'csr', 'coo']) h_g = h_g.formats(['csc', 'csr', 'coo'])
for _ in range(5): for _ in range(5):
subg = dgl.sampling.sample_etype_neighbors( subg = dgl.sampling.sample_etype_neighbors(
h_g, seeds, dgl.ETYPE, fanouts, replace=replace, edge_dir=direction) h_g, seeds, h_g_offset, fanouts, replace=replace,
edge_dir=direction)
check_num(h_g, all_src, all_dst, subg, replace, fanouts, direction) check_num(h_g, all_src, all_dst, subg, replace, fanouts, direction)
p = [g.edges[etype].data['p'] for etype in g.etypes]
subg = dgl.sampling.sample_etype_neighbors(
h_g, seeds, h_g_offset, fanouts, replace=replace,
edge_dir=direction, prob=p)
check_num(h_sg, all_sub_src, all_sub_dst, subg, replace, fanouts, direction)
p = [g.edges[etype].data['mask'] for etype in g.etypes]
subg = dgl.sampling.sample_etype_neighbors(
h_g, seeds, h_g_offset, fanouts, replace=replace,
edge_dir=direction, prob=p)
check_num(h_sg, all_sub_src, all_sub_dst, subg, replace, fanouts, direction)
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented") @unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented")
@unittest.skipIf(F.backend_name == 'mxnet', reason='MXNet has problems converting bool arrays')
@pytest.mark.parametrize('format_', ['csr', 'csc']) @pytest.mark.parametrize('format_', ['csr', 'csc'])
@pytest.mark.parametrize('direction', ['in', 'out']) @pytest.mark.parametrize('direction', ['in', 'out'])
def test_sample_neighbors_etype_sorted_homogeneous(format_, direction): def test_sample_neighbors_etype_sorted_homogeneous(format_, direction):
...@@ -919,17 +961,16 @@ def test_sample_neighbors_etype_sorted_homogeneous(format_, direction): ...@@ -919,17 +961,16 @@ def test_sample_neighbors_etype_sorted_homogeneous(format_, direction):
h_g = h_g.formats(format_) h_g = h_g.formats(format_)
if (direction, format_) in [('in', 'csr'), ('out', 'csc')]: if (direction, format_) in [('in', 'csr'), ('out', 'csc')]:
h_g = h_g.formats(['csc', 'csr', 'coo']) h_g = h_g.formats(['csc', 'csr', 'coo'])
orig_etype = F.asnumpy(h_g.edata[dgl.ETYPE])
h_g.edata[dgl.ETYPE] = F.tensor(
np.sort(orig_etype)[::-1].tolist(), dtype=F.int64)
try: if direction == 'in':
dgl.sampling.sample_etype_neighbors( h_g = dgl.sort_csc_by_tag(h_g, h_g.edata[dgl.ETYPE], tag_type='edge')
h_g, seeds, dgl.ETYPE, fanouts, edge_dir=direction, etype_sorted=True) else:
fail = False h_g = dgl.sort_csr_by_tag(h_g, h_g.edata[dgl.ETYPE], tag_type='edge')
except dgl.DGLError: # shuffle
fail = True h_g_etype = F.asnumpy(h_g.edata[dgl.ETYPE])
assert fail h_g_offset = np.cumsum(np.insert(np.bincount(h_g_etype), 0, 0)).tolist()
sg = dgl.sampling.sample_etype_neighbors(
h_g, seeds, h_g_offset, fanouts, edge_dir=direction, etype_sorted=True)
@pytest.mark.parametrize('dtype', ['int32', 'int64']) @pytest.mark.parametrize('dtype', ['int32', 'int64'])
def test_sample_neighbors_exclude_edges_heteroG(dtype): def test_sample_neighbors_exclude_edges_heteroG(dtype):
...@@ -1064,6 +1105,8 @@ if __name__ == '__main__': ...@@ -1064,6 +1105,8 @@ if __name__ == '__main__':
test_sample_neighbors_mask() test_sample_neighbors_mask()
for args in product(['coo', 'csr', 'csc'], ['in', 'out'], [False, True]): for args in product(['coo', 'csr', 'csc'], ['in', 'out'], [False, True]):
test_sample_neighbors_etype_homogeneous(*args) test_sample_neighbors_etype_homogeneous(*args)
for args in product(['csr', 'csc'], ['in', 'out']):
test_sample_neighbors_etype_sorted_homogeneous(*args)
test_non_uniform_random_walk(False) test_non_uniform_random_walk(False)
test_uniform_random_walk(False) test_uniform_random_walk(False)
test_pack_traces() test_pack_traces()
......
...@@ -279,4 +279,4 @@ TEST(RandomTest, TestBiasedChoice) { ...@@ -279,4 +279,4 @@ TEST(RandomTest, TestBiasedChoice) {
_TestBiasedChoice<int64_t, float>(re); _TestBiasedChoice<int64_t, float>(re);
_TestBiasedChoice<int32_t, double>(re); _TestBiasedChoice<int32_t, double>(re);
_TestBiasedChoice<int64_t, double>(re); _TestBiasedChoice<int64_t, double>(re);
} }
\ No newline at end of file
import dgl import dgl
import unittest import unittest
import os import os
import traceback
from dgl.data import CitationGraphDataset from dgl.data import CitationGraphDataset
from dgl.data import WN18Dataset from dgl.data import WN18Dataset
from dgl.distributed import sample_neighbors, sample_etype_neighbors from dgl.distributed import sample_neighbors, sample_etype_neighbors
...@@ -35,7 +36,7 @@ def start_sample_client(rank, tmpdir, disable_shared_mem): ...@@ -35,7 +36,7 @@ def start_sample_client(rank, tmpdir, disable_shared_mem):
try: try:
sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3) sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)
except Exception as e: except Exception as e:
print(e) print(traceback.format_exc())
sampled_graph = None sampled_graph = None
dgl.distributed.exit_client() dgl.distributed.exit_client()
return sampled_graph return sampled_graph
...@@ -69,7 +70,7 @@ def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids, etype=None): ...@@ -69,7 +70,7 @@ def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids, etype=None):
try: try:
u, v = dist_graph.find_edges(eids, etype=etype) u, v = dist_graph.find_edges(eids, etype=etype)
except Exception as e: except Exception as e:
print(e) print(traceback.format_exc())
u, v = None, None u, v = None, None
dgl.distributed.exit_client() dgl.distributed.exit_client()
return u, v return u, v
...@@ -86,7 +87,7 @@ def start_get_degrees_client(rank, tmpdir, disable_shared_mem, nids=None): ...@@ -86,7 +87,7 @@ def start_get_degrees_client(rank, tmpdir, disable_shared_mem, nids=None):
out_deg = dist_graph.out_degrees(nids) out_deg = dist_graph.out_degrees(nids)
all_out_deg = dist_graph.out_degrees() all_out_deg = dist_graph.out_degrees()
except Exception as e: except Exception as e:
print(e) print(traceback.format_exc())
in_deg, out_deg, all_in_deg, all_out_deg = None, None, None, None in_deg, out_deg, all_in_deg, all_out_deg = None, None, None, None
dgl.distributed.exit_client() dgl.distributed.exit_client()
return in_deg, out_deg, all_in_deg, all_out_deg return in_deg, out_deg, all_in_deg, all_out_deg
...@@ -329,7 +330,7 @@ def start_hetero_sample_client(rank, tmpdir, disable_shared_mem, nodes): ...@@ -329,7 +330,7 @@ def start_hetero_sample_client(rank, tmpdir, disable_shared_mem, nodes):
block = dgl.to_block(sampled_graph, nodes) block = dgl.to_block(sampled_graph, nodes)
block.edata[dgl.EID] = sampled_graph.edata[dgl.EID] block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
except Exception as e: except Exception as e:
print(e) print(traceback.format_exc())
block = None block = None
dgl.distributed.exit_client() dgl.distributed.exit_client()
return block, gpb return block, gpb
...@@ -359,11 +360,12 @@ def start_hetero_etype_sample_client(rank, tmpdir, disable_shared_mem, fanout=3, ...@@ -359,11 +360,12 @@ def start_hetero_etype_sample_client(rank, tmpdir, disable_shared_mem, fanout=3,
if gpb is None: if gpb is None:
gpb = dist_graph.get_partition_book() gpb = dist_graph.get_partition_book()
try: try:
sampled_graph = sample_etype_neighbors(dist_graph, nodes, dgl.ETYPE, fanout, etype_sorted=etype_sorted) sampled_graph = sample_etype_neighbors(
dist_graph, nodes, fanout, etype_sorted=etype_sorted)
block = dgl.to_block(sampled_graph, nodes) block = dgl.to_block(sampled_graph, nodes)
block.edata[dgl.EID] = sampled_graph.edata[dgl.EID] block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
except Exception as e: except Exception as e:
print(e) print(traceback.format_exc())
block = None block = None
dgl.distributed.exit_client() dgl.distributed.exit_client()
return block, gpb return block, gpb
...@@ -581,8 +583,7 @@ def start_bipartite_etype_sample_client(rank, tmpdir, disable_shared_mem, fanout ...@@ -581,8 +583,7 @@ def start_bipartite_etype_sample_client(rank, tmpdir, disable_shared_mem, fanout
if gpb is None: if gpb is None:
gpb = dist_graph.get_partition_book() gpb = dist_graph.get_partition_book()
sampled_graph = sample_etype_neighbors( sampled_graph = sample_etype_neighbors(dist_graph, nodes, fanout)
dist_graph, nodes, dgl.ETYPE, fanout)
block = dgl.to_block(sampled_graph, nodes) block = dgl.to_block(sampled_graph, nodes)
if sampled_graph.num_edges() > 0: if sampled_graph.num_edges() > 0:
block.edata[dgl.EID] = sampled_graph.edata[dgl.EID] block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
...@@ -783,6 +784,10 @@ def test_rpc_sampling_shuffle(num_server): ...@@ -783,6 +784,10 @@ def test_rpc_sampling_shuffle(num_server):
def check_standalone_sampling(tmpdir, reshuffle): def check_standalone_sampling(tmpdir, reshuffle):
g = CitationGraphDataset("cora")[0] g = CitationGraphDataset("cora")[0]
prob = np.maximum(np.random.randn(g.num_edges()), 0)
mask = (prob > 0)
g.edata['prob'] = F.tensor(prob)
g.edata['mask'] = F.tensor(mask)
num_parts = 1 num_parts = 1
num_hops = 1 num_hops = 1
partition_graph(g, 'test_sampling', num_parts, tmpdir, partition_graph(g, 'test_sampling', num_parts, tmpdir,
...@@ -799,10 +804,24 @@ def check_standalone_sampling(tmpdir, reshuffle): ...@@ -799,10 +804,24 @@ def check_standalone_sampling(tmpdir, reshuffle):
eids = g.edge_ids(src, dst) eids = g.edge_ids(src, dst)
assert np.array_equal( assert np.array_equal(
F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids)) F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids))
sampled_graph = sample_neighbors(
dist_graph, [0, 10, 99, 66, 1024, 2008], 3, prob='mask')
eid = F.asnumpy(sampled_graph.edata[dgl.EID])
assert mask[eid].all()
sampled_graph = sample_neighbors(
dist_graph, [0, 10, 99, 66, 1024, 2008], 3, prob='prob')
eid = F.asnumpy(sampled_graph.edata[dgl.EID])
assert (prob[eid] > 0).all()
dgl.distributed.exit_client() dgl.distributed.exit_client()
def check_standalone_etype_sampling(tmpdir, reshuffle): def check_standalone_etype_sampling(tmpdir, reshuffle):
hg = CitationGraphDataset('cora')[0] hg = CitationGraphDataset('cora')[0]
prob = np.maximum(np.random.randn(hg.num_edges()), 0)
mask = (prob > 0)
hg.edata['prob'] = F.tensor(prob)
hg.edata['mask'] = F.tensor(mask)
num_parts = 1 num_parts = 1
num_hops = 1 num_hops = 1
...@@ -811,7 +830,7 @@ def check_standalone_etype_sampling(tmpdir, reshuffle): ...@@ -811,7 +830,7 @@ def check_standalone_etype_sampling(tmpdir, reshuffle):
os.environ['DGL_DIST_MODE'] = 'standalone' os.environ['DGL_DIST_MODE'] = 'standalone'
dgl.distributed.initialize("rpc_ip_config.txt") dgl.distributed.initialize("rpc_ip_config.txt")
dist_graph = DistGraph("test_sampling", part_config=tmpdir / 'test_sampling.json') dist_graph = DistGraph("test_sampling", part_config=tmpdir / 'test_sampling.json')
sampled_graph = sample_etype_neighbors(dist_graph, [0, 10, 99, 66, 1023], dgl.ETYPE, 3) sampled_graph = sample_etype_neighbors(dist_graph, [0, 10, 99, 66, 1023], 3)
src, dst = sampled_graph.edges() src, dst = sampled_graph.edges()
assert sampled_graph.number_of_nodes() == hg.number_of_nodes() assert sampled_graph.number_of_nodes() == hg.number_of_nodes()
...@@ -819,6 +838,16 @@ def check_standalone_etype_sampling(tmpdir, reshuffle): ...@@ -819,6 +838,16 @@ def check_standalone_etype_sampling(tmpdir, reshuffle):
eids = hg.edge_ids(src, dst) eids = hg.edge_ids(src, dst)
assert np.array_equal( assert np.array_equal(
F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids)) F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids))
sampled_graph = sample_etype_neighbors(
dist_graph, [0, 10, 99, 66, 1023], 3, prob='mask')
eid = F.asnumpy(sampled_graph.edata[dgl.EID])
assert mask[eid].all()
sampled_graph = sample_etype_neighbors(
dist_graph, [0, 10, 99, 66, 1023], 3, prob='prob')
eid = F.asnumpy(sampled_graph.edata[dgl.EID])
assert (prob[eid] > 0).all()
dgl.distributed.exit_client() dgl.distributed.exit_client()
def check_standalone_etype_sampling_heterograph(tmpdir, reshuffle): def check_standalone_etype_sampling_heterograph(tmpdir, reshuffle):
...@@ -834,7 +863,8 @@ def check_standalone_etype_sampling_heterograph(tmpdir, reshuffle): ...@@ -834,7 +863,8 @@ def check_standalone_etype_sampling_heterograph(tmpdir, reshuffle):
os.environ['DGL_DIST_MODE'] = 'standalone' os.environ['DGL_DIST_MODE'] = 'standalone'
dgl.distributed.initialize("rpc_ip_config.txt") dgl.distributed.initialize("rpc_ip_config.txt")
dist_graph = DistGraph("test_hetero_sampling", part_config=tmpdir / 'test_hetero_sampling.json') dist_graph = DistGraph("test_hetero_sampling", part_config=tmpdir / 'test_hetero_sampling.json')
sampled_graph = sample_etype_neighbors(dist_graph, [0, 1, 2, 10, 99, 66, 1023, 1024, 2700, 2701], dgl.ETYPE, 1) sampled_graph = sample_etype_neighbors(
dist_graph, [0, 1, 2, 10, 99, 66, 1023, 1024, 2700, 2701], 1)
src, dst = sampled_graph.edges(etype=('paper', 'cite', 'paper')) src, dst = sampled_graph.edges(etype=('paper', 'cite', 'paper'))
assert len(src) == 10 assert len(src) == 10
src, dst = sampled_graph.edges(etype=('paper', 'cite-by', 'paper')) src, dst = sampled_graph.edges(etype=('paper', 'cite-by', 'paper'))
...@@ -861,7 +891,7 @@ def start_in_subgraph_client(rank, tmpdir, disable_shared_mem, nodes): ...@@ -861,7 +891,7 @@ def start_in_subgraph_client(rank, tmpdir, disable_shared_mem, nodes):
try: try:
sampled_graph = dgl.distributed.in_subgraph(dist_graph, nodes) sampled_graph = dgl.distributed.in_subgraph(dist_graph, nodes)
except Exception as e: except Exception as e:
print(e) print(traceback.format_exc())
sampled_graph = None sampled_graph = None
dgl.distributed.exit_client() dgl.distributed.exit_client()
return sampled_graph return sampled_graph
...@@ -925,7 +955,6 @@ def test_standalone_etype_sampling(): ...@@ -925,7 +955,6 @@ def test_standalone_etype_sampling():
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
os.environ['DGL_DIST_MODE'] = 'standalone' os.environ['DGL_DIST_MODE'] = 'standalone'
check_standalone_etype_sampling(Path(tmpdirname), True) check_standalone_etype_sampling(Path(tmpdirname), True)
check_standalone_etype_sampling(Path(tmpdirname), False)
if __name__ == "__main__": if __name__ == "__main__":
import tempfile import tempfile
......
...@@ -71,14 +71,18 @@ def test_saint(num_workers, mode): ...@@ -71,14 +71,18 @@ def test_saint(num_workers, mode):
@parametrize_idtype @parametrize_idtype
@pytest.mark.parametrize('mode', ['cpu', 'uva_cuda_indices', 'uva_cpu_indices', 'pure_gpu']) @pytest.mark.parametrize('mode', ['cpu', 'uva_cuda_indices', 'uva_cpu_indices', 'pure_gpu'])
@pytest.mark.parametrize('use_ddp', [False, True]) @pytest.mark.parametrize('use_ddp', [False, True])
def test_neighbor_nonuniform(idtype, mode, use_ddp): @pytest.mark.parametrize('use_mask', [False, True])
def test_neighbor_nonuniform(idtype, mode, use_ddp, use_mask):
if mode != 'cpu' and F.ctx() == F.cpu(): if mode != 'cpu' and F.ctx() == F.cpu():
pytest.skip('UVA and GPU sampling require a GPU.') pytest.skip('UVA and GPU sampling require a GPU.')
if mode != 'cpu' and use_mask:
pytest.skip('Masked sampling only works on CPU.')
if use_ddp: if use_ddp:
dist.init_process_group('gloo' if F.ctx() == F.cpu() else 'nccl', dist.init_process_group('gloo' if F.ctx() == F.cpu() else 'nccl',
'tcp://127.0.0.1:12347', world_size=1, rank=0) 'tcp://127.0.0.1:12347', world_size=1, rank=0)
g = dgl.graph(([1, 2, 3, 4, 5, 6, 7, 8], [0, 0, 0, 0, 1, 1, 1, 1])).astype(idtype) g = dgl.graph(([1, 2, 3, 4, 5, 6, 7, 8], [0, 0, 0, 0, 1, 1, 1, 1])).astype(idtype)
g.edata['p'] = torch.FloatTensor([1, 1, 0, 0, 1, 1, 0, 0]) g.edata['p'] = torch.FloatTensor([1, 1, 0, 0, 1, 1, 0, 0])
g.edata['mask'] = (g.edata['p'] != 0)
if mode in ('cpu', 'uva_cpu_indices'): if mode in ('cpu', 'uva_cpu_indices'):
indices = F.copy_to(F.tensor([0, 1], idtype), F.cpu()) indices = F.copy_to(F.tensor([0, 1], idtype), F.cpu())
else: else:
...@@ -87,7 +91,12 @@ def test_neighbor_nonuniform(idtype, mode, use_ddp): ...@@ -87,7 +91,12 @@ def test_neighbor_nonuniform(idtype, mode, use_ddp):
g = g.to(F.cuda()) g = g.to(F.cuda())
use_uva = mode.startswith('uva') use_uva = mode.startswith('uva')
sampler = dgl.dataloading.MultiLayerNeighborSampler([2], prob='p') if use_mask:
prob, mask = None, 'mask'
else:
prob, mask = 'p', None
sampler = dgl.dataloading.MultiLayerNeighborSampler([2], prob=prob, mask=mask)
for num_workers in [0, 1, 2] if mode == 'cpu' else [0]: for num_workers in [0, 1, 2] if mode == 'cpu' else [0]:
dataloader = dgl.dataloading.NodeDataLoader( dataloader = dgl.dataloading.NodeDataLoader(
g, indices, sampler, g, indices, sampler,
...@@ -108,7 +117,9 @@ def test_neighbor_nonuniform(idtype, mode, use_ddp): ...@@ -108,7 +117,9 @@ def test_neighbor_nonuniform(idtype, mode, use_ddp):
('C', 'CA', 'A'): ([1, 2, 3, 4, 5, 6, 7, 8], [0, 0, 0, 0, 1, 1, 1, 1]), ('C', 'CA', 'A'): ([1, 2, 3, 4, 5, 6, 7, 8], [0, 0, 0, 0, 1, 1, 1, 1]),
}).astype(idtype) }).astype(idtype)
g.edges['BA'].data['p'] = torch.FloatTensor([1, 1, 0, 0, 1, 1, 0, 0]) g.edges['BA'].data['p'] = torch.FloatTensor([1, 1, 0, 0, 1, 1, 0, 0])
g.edges['BA'].data['mask'] = (g.edges['BA'].data['p'] != 0)
g.edges['CA'].data['p'] = torch.FloatTensor([0, 0, 1, 1, 0, 0, 1, 1]) g.edges['CA'].data['p'] = torch.FloatTensor([0, 0, 1, 1, 0, 0, 1, 1])
g.edges['CA'].data['mask'] = (g.edges['CA'].data['p'] != 0)
if mode == 'pure_gpu': if mode == 'pure_gpu':
g = g.to(F.cuda()) g = g.to(F.cuda())
for num_workers in [0, 1, 2] if mode == 'cpu' else [0]: for num_workers in [0, 1, 2] if mode == 'cpu' else [0]:
......