Unverified commit e56bbafd, authored by Qidong Su, committed by GitHub
Browse files

[Feature] Biased Neighbor Sampling (#2987)



* update

* update

* update

* update

* lint

* lint

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* lint

* update

* clone

* update

* update

* update

* update

* replace idarray with ndarray

* refactor cpp part

* refactor python part

* debug

* refactor interface

* test and doc

* lint and test

* lint

* fix

* fix

* fix

* const

* doc

* fix

* fix

* fix

* fix

* fix & doc

* fix

* fix

* update

* update

* update

* merge

* doc

* doc

* lint

* fix

* more tests

* doc

* fix

* fix

* update

* update

* update

* fix

* fix
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
parent 7415eaa5
...@@ -430,6 +430,71 @@ COOMatrix CSRRowWiseTopk( ...@@ -430,6 +430,71 @@ COOMatrix CSRRowWiseTopk(
FloatArray weight, FloatArray weight,
bool ascending = false); bool ascending = false);
/*!
 * \brief Randomly select a fixed number of non-zero entries along each given row independently,
 *        where the probability of a column being picked can be biased according to its tag.
 *
 * Each column is assigned an integer tag which determines its probability to be sampled.
 * Users can assign a different probability to each tag.
 *
 * This function only works with a CSR matrix sorted according to the tag so that entries with
 * the same column tag are arranged in a consecutive range, and the input `tag_offset` represents
 * the boundaries of these ranges. However, the function itself will not check if the input matrix
 * has been sorted. It's the caller's responsibility to ensure the input matrix has been sorted
 * by `CSRSortByTag` (it will also return a NDArray `tag_offset` which should be used as an input
 * of this function).
 *
 * The picked indices are returned in the form of a COO matrix.
 *
 * If replace is false and a row has fewer non-zero values than num_samples,
 * all the values are picked.
 *
 * Examples:
 *
 * // csr.num_rows = 4;
 * // csr.num_cols = 4;
 * // csr.indptr = [0, 2, 4, 5, 5]
 * // csr.indices = [1, 2, 2, 3, 3]
 * // tag of each element's column: 0, 0, 0, 1, 1
 * // tag_offset = [[0, 2, 2], [0, 1, 2], [0, 0, 1], [0, 0, 0]]
 * //   (one row of offsets per matrix row; row 3 has no entries)
 * // csr.data = [2, 3, 0, 1, 4]
 * // bias = [1.0, 0.0]
 * CSRMatrix mat = ...;
 * IdArray rows = ...; //[0, 1]
 * NDArray tag_offset = ...;
 * FloatArray bias = ...;
 * COOMatrix sampled = CSRRowWiseSamplingBiased(mat, rows, 1, tag_offset, bias);
 * // possible sampled coo matrix:
 * // sampled.num_rows = 4
 * // sampled.num_cols = 4
 * // sampled.rows = [0, 1]
 * // sampled.cols = [1, 2]
 * // sampled.data = [2, 0]
 * // Note that in this case, for row 1, the column 3 will never be picked as it has tag 1 and the
 * // probability of tag 1 is 0.
 *
 *
 * \param mat Input CSR matrix.
 * \param rows Rows to sample from.
 * \param num_samples Number of samples.
 * \param tag_offset The boundaries of tags. Should be of the shape [num_rows, num_tags + 1].
 * \param bias Unnormalized probability array. Should be of length num_tags.
 * \param replace True if sample with replacement (this is the default).
 * \return A COOMatrix storing the picked row and col indices. Its data field stores
 *         the index of the picked elements in the value array.
 *
 */
COOMatrix CSRRowWiseSamplingBiased(
    CSRMatrix mat,
    IdArray rows,
    int64_t num_samples,
    NDArray tag_offset,
    FloatArray bias,
    bool replace = true
);
/*! /*!
* \brief Sort the column index according to the tag of each column. * \brief Sort the column index according to the tag of each column.
* *
......
...@@ -69,6 +69,15 @@ HeteroSubgraph SampleNeighborsTopk( ...@@ -69,6 +69,15 @@ HeteroSubgraph SampleNeighborsTopk(
const std::vector<FloatArray>& weight, const std::vector<FloatArray>& weight,
bool ascending = false); bool ascending = false);
/*!
 * \brief Sample neighboring edges of the given nodes, where the probability of a neighbor
 *        being picked is determined by its tag.
 *
 * Only graphs with a single edge type (homogeneous or bipartite) are accepted.
 * The graph's CSR (for out-edges) or CSC (for in-edges) is expected to be already sorted
 * by tag; this is not verified.
 *
 * \param hg The input graph.
 * \param nodes IDs of the seed nodes to sample neighbors for.
 * \param fanouts Number of edges to sample per seed node. 0 yields an empty subgraph and
 *        -1 selects all neighboring edges.
 * \param bias Unnormalized probability of each tag; its length is the number of tags.
 * \param tag_offset Tag boundaries of shape [num_nodes, num_tags + 1].
 * \param dir Whether in-edges or out-edges of the seed nodes are sampled.
 * \param replace True to sample with replacement.
 * \return The sampled subgraph together with the IDs of the picked edges.
 */
HeteroSubgraph SampleNeighborsBiased(
    const HeteroGraphPtr hg,
    const IdArray& nodes,
    const int64_t fanouts,
    const NDArray& bias,
    const NDArray& tag_offset,
    const EdgeDir dir,
    const bool replace
);
} // namespace sampling } // namespace sampling
} // namespace dgl } // namespace dgl
......
...@@ -9,6 +9,7 @@ from .. import utils ...@@ -9,6 +9,7 @@ from .. import utils
__all__ = [ __all__ = [
'sample_neighbors', 'sample_neighbors',
'sample_neighbors_biased',
'select_topk'] 'select_topk']
def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False, def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False,
...@@ -179,6 +180,163 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False, ...@@ -179,6 +180,163 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False,
return ret return ret
def sample_neighbors_biased(g, nodes, fanout, bias, edge_dir='in',
                            tag_offset_name='_TAG_OFFSET', replace=False,
                            copy_ndata=True, copy_edata=True):
    """Sample neighboring edges of the given nodes with a per-tag bias and return
    the induced subgraph.

    For every seed node, a number of inbound (or outbound when ``edge_dir == 'out'``)
    edges is chosen at random. The returned graph contains all the nodes of the
    original graph, but only the sampled edges.

    This sampler covers the case where neighbors of different kinds should be picked
    with different probabilities. Every node carries an integer *tag* describing its
    kind; a tag plays the role of a node type inside a homogeneous graph. All nodes
    with the same tag share the same (unnormalized) probability.

    For example, suppose a node has ``a + b`` neighbors, ``a`` of them with tag 0 and
    ``b`` of them with tag 1, and that a tag-0 node has unnormalized probability ``p``
    while a tag-1 node has ``q``. A tag is first drawn from the unnormalized
    distribution ``(a * p, b * q)``, and then a neighbor is drawn uniformly among the
    neighbors carrying that tag.

    To make this efficient the CSR matrix of the graph must be sorted by tag beforehand
    (see `dgl.transform.sort_in_edges` and `dgl.transform.sort_out_edges`). Sorting puts
    the neighbors with equal tags in one consecutive range and stores the offsets of
    these ranges in a node feature named ``tag_offset_name``.

    The graph must have been sorted with the function matching the edge direction
    ('in' or 'out'). This function does NOT check that the graph is sorted, and
    ``tag_offset_name`` must be the same name that was given to the sorting function.

    Only homogeneous or bipartite graphs are supported. For bipartite graphs, only the
    candidate frontier nodes have tags (source nodes when ``edge_dir='in'``, destination
    nodes when ``edge_dir='out'``), and the tag offsets must be stored as a node feature
    of the seed nodes.

    Node and edge features are carried over according to ``copy_ndata`` and
    ``copy_edata``. The original IDs of the sampled edges are stored as the `dgl.EID`
    feature of the returned graph.

    Parameters
    ----------
    g : DGLGraph
        The graph. Must be homogeneous or bipartite (only one edge type). Must be on CPU.
    nodes : tensor or list
        Node IDs to sample neighbors from.
    fanout : int
        The number of edges to be sampled for each node.

        If -1 is given, all the neighboring edges will be selected.
    bias : tensor or list
        The (unnormalized) probabilities associated with each tag. Its length should be
        equal to the number of tags.

        Entries of this array must be non-negative floats, and the sum of the entries
        must be positive (though they don't have to sum up to one). Otherwise, the
        result will be undefined.
    edge_dir : str, optional
        Determines whether to sample inbound or outbound edges.

        Can take either ``in`` for inbound edges or ``out`` for outbound edges.
    tag_offset_name : str, optional
        The name of the node feature storing tag offsets.

        (Default: "_TAG_OFFSET")
    replace : bool, optional
        If True, sample with replacement.
    copy_ndata: bool, optional
        If True, the node features of the new graph are copied from
        the original graph. If False, the new graph will not have any
        node features.

        (Default: True)
    copy_edata: bool, optional
        If True, the edge features of the new graph are copied from
        the original graph. If False, the new graph will not have any
        edge features.

        (Default: True)

    Returns
    -------
    DGLGraph
        A sampled subgraph containing only the sampled neighboring edges. It is on CPU.

    Raises
    ------
    DGLError
        If ``edge_dir`` is neither ``'in'`` nor ``'out'``.

    Notes
    -----
    If :attr:`copy_ndata` or :attr:`copy_edata` is True, same tensors are used as
    the node or edge features of the original graph and the new graph.
    As a result, users should avoid performing in-place operations
    on the node features of the new graph to avoid feature corruption.

    Examples
    --------
    Assume that you have the following graph

    >>> g = dgl.graph(([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 2, 0]))

    And the tags

    >>> tag = torch.IntTensor([0, 0, 1])

    Sort the graph (necessary!)

    >>> g_sorted = dgl.transform.sort_out_edges(g, tag)
    >>> g_sorted.ndata['_TAG_OFFSET']
    tensor([[0, 1, 2],
            [0, 2, 2],
            [0, 1, 2]])

    Set the probability of each tag:

    >>> bias = torch.tensor([1.0, 0.001])
    # node 2 is almost impossible to be sampled because it has tag 1.

    To sample one out bound edge for node 0 and node 2:

    >>> sg = dgl.sampling.sample_neighbors_biased(g_sorted, [0, 2], 1, bias, edge_dir='out')
    >>> sg.edges(order='eid')
    (tensor([0, 2]), tensor([1, 0]))
    >>> sg.edata[dgl.EID]
    tensor([0, 5])

    With ``fanout`` greater than the number of actual neighbors and without replacement,
    DGL will take all neighbors instead:

    >>> sg = dgl.sampling.sample_neighbors_biased(g_sorted, [0, 2], 3, bias, edge_dir='out')
    >>> sg.edges(order='eid')
    (tensor([0, 0, 2, 2]), tensor([1, 2, 0, 2]))
    """
    # Plain Python lists are accepted for convenience; turn them into backend tensors.
    if isinstance(nodes, list):
        nodes = F.tensor(nodes)
    if isinstance(bias, list):
        bias = F.tensor(bias)

    seed_nd = F.to_dgl_nd(nodes)
    bias_nd = F.to_dgl_nd(bias)

    # The tag offsets live on the seed side of the relation: destination nodes for
    # inbound sampling, source nodes for outbound sampling.
    if edge_dir not in ('in', 'out'):
        raise DGLError("edge_dir can only be 'in' or 'out'")
    seed_frame = g.dstdata if edge_dir == 'in' else g.srcdata
    offset_nd = F.to_dgl_nd(seed_frame[tag_offset_name])

    subgidx = _CAPI_DGLSampleNeighborsBiased(
        g._graph, seed_nd, fanout, bias_nd, offset_nd, edge_dir, replace)
    picked_eids = subgidx.induced_edges
    subg = DGLHeteroGraph(subgidx.graph, g.ntypes, g.etypes)

    if copy_ndata:
        # All nodes are kept, so the node frames are taken over as a whole.
        utils.set_new_frames(subg, node_frames=utils.extract_node_subframes(g, None))
    if copy_edata:
        utils.set_new_frames(
            subg, edge_frames=utils.extract_edge_subframes(g, picked_eids))

    # There is exactly one edge type, hence a single induced-edge array.
    subg.edata[EID] = picked_eids[0]
    return subg
def select_topk(g, k, weight, nodes=None, edge_dir='in', ascending=False, def select_topk(g, k, weight, nodes=None, edge_dir='in', ascending=False,
copy_ndata=True, copy_edata=True): copy_ndata=True, copy_edata=True):
"""Select the neighboring edges with k-largest (or k-smallest) weights of the given """Select the neighboring edges with k-largest (or k-smallest) weights of the given
......
...@@ -580,6 +580,23 @@ COOMatrix CSRRowWiseTopk( ...@@ -580,6 +580,23 @@ COOMatrix CSRRowWiseTopk(
return ret; return ret;
} }
/*!
 * \brief Public entry point of biased row-wise sampling on a CSR matrix.
 *
 * Forwards all arguments unchanged to impl::CSRRowWiseSamplingBiased, instantiated with
 * the XPU / IdType selected by ATEN_CSR_SWITCH for `mat` and the FloatType selected by
 * ATEN_FLOAT_TYPE_SWITCH for `bias->dtype`.
 *
 * NOTE(review): the dtype of `tag_offset` is not inspected here; presumably it has to
 * match the index type of `mat` -- confirm against the implementation.
 *
 * \param mat Input CSR matrix.
 * \param rows Rows to sample from.
 * \param num_samples Number of samples per row.
 * \param tag_offset Tag boundaries of each row.
 * \param bias Unnormalized probability of each tag.
 * \param replace True if sample with replacement.
 * \return The COO matrix produced by the selected implementation.
 */
COOMatrix CSRRowWiseSamplingBiased(
    CSRMatrix mat,
    IdArray rows,
    int64_t num_samples,
    NDArray tag_offset,
    FloatArray bias,
    bool replace) {
  COOMatrix ret;
  // Outer switch: dispatch on the matrix; inner switch: dispatch on the bias dtype.
  ATEN_CSR_SWITCH(mat, XPU, IdType, "CSRRowWiseSamplingBiased", {
    ATEN_FLOAT_TYPE_SWITCH(bias->dtype, FloatType, "bias", {
      ret = impl::CSRRowWiseSamplingBiased<XPU, IdType, FloatType>(
          mat, rows, num_samples, tag_offset, bias, replace);
    });
  });
  return ret;
}
CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs) { CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs) {
CSRMatrix ret; CSRMatrix ret;
......
...@@ -173,6 +173,16 @@ template <DLDeviceType XPU, typename IdType, typename DType> ...@@ -173,6 +173,16 @@ template <DLDeviceType XPU, typename IdType, typename DType>
COOMatrix CSRRowWiseTopk( COOMatrix CSRRowWiseTopk(
CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending); CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending);
// Biased row-wise sampling on a CSR matrix.
// Instantiated per device (XPU), index type (IdType) and bias value type (FloatType);
// takes the same arguments as aten::CSRRowWiseSamplingBiased, without a default for `replace`.
template <DLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix CSRRowWiseSamplingBiased(
    CSRMatrix mat,
    IdArray rows,
    int64_t num_samples,
    NDArray tag_offset,
    FloatArray bias,
    bool replace
);
// Union CSRMatrixes // Union CSRMatrixes
template <DLDeviceType XPU, typename IdType> template <DLDeviceType XPU, typename IdType>
CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs); CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs);
......
...@@ -59,6 +59,24 @@ inline PickFn<IdxType> GetSamplingUniformPickFn( ...@@ -59,6 +59,24 @@ inline PickFn<IdxType> GetSamplingUniformPickFn(
}; };
return pick_fn; return pick_fn;
} }
/*!
 * \brief Build the per-row pick function used by biased row-wise sampling.
 *
 * The returned function looks up the tag offsets of the requested row in `split`
 * (a 2D array with one row of offsets per matrix row), lets the thread-local random
 * engine draw `num_samples` positions relative to the row start, and finally shifts
 * those positions by the row's starting offset so that they index into the matrix.
 *
 * \param num_samples Number of entries to pick per row.
 * \param split Tag offsets; its data is read as IdxType.
 * \param bias Unnormalized probability of each tag.
 * \param replace True if sample with replacement.
 * \return The pick function.
 */
template <typename IdxType, typename FloatType>
inline PickFn<IdxType> GetSamplingBiasedPickFn(
    int64_t num_samples, IdArray split, FloatArray bias, bool replace) {
  // The arrays are captured by value so that the closure keeps them alive.
  return [num_samples, split, bias, replace]
    (IdxType rowid, IdxType off, IdxType len,
     const IdxType* col, const IdxType* data,
     IdxType* out_idx) {
      const IdxType *all_offsets = static_cast<IdxType *>(split->data);
      const IdxType *row_offsets = all_offsets + rowid * split->shape[1];
      RandomEngine::ThreadLocal()->BiasedChoice<IdxType, FloatType>(
          num_samples, row_offsets, bias, out_idx, replace);
      // The choices are relative to the beginning of the row; make them absolute.
      int64_t k = 0;
      while (k < num_samples) {
        out_idx[k] += off;
        ++k;
      }
    };
}
} // namespace } // namespace
/////////////////////////////// CSR /////////////////////////////// /////////////////////////////// CSR ///////////////////////////////
...@@ -92,6 +110,33 @@ template COOMatrix CSRRowWiseSamplingUniform<kDLCPU, int32_t>( ...@@ -92,6 +110,33 @@ template COOMatrix CSRRowWiseSamplingUniform<kDLCPU, int32_t>(
template COOMatrix CSRRowWiseSamplingUniform<kDLCPU, int64_t>( template COOMatrix CSRRowWiseSamplingUniform<kDLCPU, int64_t>(
CSRMatrix, IdArray, int64_t, bool); CSRMatrix, IdArray, int64_t, bool);
/*!
 * \brief Biased row-wise sampling on a CSR matrix.
 *
 * Builds the biased pick function for the given tag offsets and bias, then hands it
 * to the generic row-wise picking routine.
 */
template <DLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix CSRRowWiseSamplingBiased(
    CSRMatrix mat,
    IdArray rows,
    int64_t num_samples,
    NDArray tag_offset,
    FloatArray bias,
    bool replace
) {
  PickFn<IdxType> biased_picker =
      GetSamplingBiasedPickFn<IdxType, FloatType>(num_samples, tag_offset, bias, replace);
  return CSRRowWisePick(mat, rows, num_samples, replace, biased_picker);
}

// Explicit instantiations: CPU only, for every index type / bias type combination.
template COOMatrix CSRRowWiseSamplingBiased<kDLCPU, int32_t, float>(
    CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);

template COOMatrix CSRRowWiseSamplingBiased<kDLCPU, int64_t, float>(
    CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);

template COOMatrix CSRRowWiseSamplingBiased<kDLCPU, int32_t, double>(
    CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);

template COOMatrix CSRRowWiseSamplingBiased<kDLCPU, int64_t, double>(
    CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
/////////////////////////////// COO /////////////////////////////// /////////////////////////////// COO ///////////////////////////////
template <DLDeviceType XPU, typename IdxType, typename FloatType> template <DLDeviceType XPU, typename IdxType, typename FloatType>
......
...@@ -190,6 +190,85 @@ HeteroSubgraph SampleNeighborsTopk( ...@@ -190,6 +190,85 @@ HeteroSubgraph SampleNeighborsTopk(
return ret; return ret;
} }
/*!
 * \brief Sample neighboring edges of the given nodes with a per-tag bias.
 *
 * The graph must have exactly one edge type. Depending on `dir`, the CSR (out-edges)
 * or CSC (in-edges) format must already exist on the graph and is expected to be
 * sorted by tag; sortedness itself is not verified.
 *
 * \param hg The input graph.
 * \param nodes IDs of the seed nodes.
 * \param fanout Number of edges to sample per seed node. 0 gives an empty relation
 *        graph and -1 takes every edge incident to the seed nodes.
 * \param bias Unnormalized probability of each tag; length is the number of tags.
 * \param tag_offset Tag boundaries of shape [num_nodes, num_tags + 1], where num_nodes
 *        is the number of nodes of the seed node type.
 * \param dir Whether to sample in-edges or out-edges.
 * \param replace True to sample with replacement.
 * \return The sampled subgraph; induced_edges holds the IDs of the picked edges.
 */
HeteroSubgraph SampleNeighborsBiased(
    const HeteroGraphPtr hg,
    const IdArray& nodes,
    const int64_t fanout,
    const NDArray& bias,
    const NDArray& tag_offset,
    const EdgeDir dir,
    const bool replace
  ) {
  CHECK_EQ(hg->NumEdgeTypes(), 1) << "Only homogeneous or bipartite graphs are supported";
  auto pair = hg->meta_graph()->FindEdge(0);
  const dgl_type_t src_vtype = pair.first;
  const dgl_type_t dst_vtype = pair.second;
  // Node type the seed nodes belong to: sources for out-edges, destinations for in-edges.
  const dgl_type_t nodes_ntype = (dir == EdgeDir::kOut) ? src_vtype : dst_vtype;

  // sanity check
  // NOTE(review): the dtype of tag_offset is not validated; the CPU kernel reads it with
  // the index type of the graph -- confirm that callers always guarantee this.
  CHECK_EQ(tag_offset->ndim, 2) << "The shape of tag_offset should be [num_nodes, num_tags + 1]";
  CHECK_EQ(tag_offset->shape[0], hg->NumVertices(nodes_ntype))
    << "The shape of tag_offset should be [num_nodes, num_tags + 1]";
  CHECK_EQ(tag_offset->shape[1], bias->shape[0] + 1)
    << "The sizes of tag_offset and bias are inconsistent";

  const int64_t num_nodes = nodes->shape[0];
  HeteroGraphPtr subrel;
  IdArray induced_edges;
  const dgl_type_t etype = 0;
  if (num_nodes == 0 || fanout == 0) {
    // Nothing to sample for this etype, create a placeholder relation graph
    subrel = UnitGraph::Empty(
      hg->GetRelationGraph(etype)->NumVertexTypes(),
      hg->NumVertices(src_vtype),
      hg->NumVertices(dst_vtype),
      hg->DataType(), hg->Context());
    induced_edges = aten::NullArray();
  } else if (fanout == -1) {
    // Take every edge incident to the seed nodes. The query has to be made with the
    // seed node IDs (`nodes`); `nodes_ntype` is only the ID of their node type.
    const auto &earr = (dir == EdgeDir::kOut) ?
      hg->OutEdges(etype, nodes) :
      hg->InEdges(etype, nodes);
    subrel = UnitGraph::CreateFromCOO(
      hg->GetRelationGraph(etype)->NumVertexTypes(),
      hg->NumVertices(src_vtype),
      hg->NumVertices(dst_vtype),
      earr.src,
      earr.dst);
    induced_edges = earr.id;
  } else {
    // sample from one relation graph
    const auto req_fmt = (dir == EdgeDir::kOut)? CSR_CODE : CSC_CODE;
    const auto created_fmt = hg->GetCreatedFormats();
    COOMatrix sampled_coo;

    switch (req_fmt) {
      case CSR_CODE:
        // The format must already exist: a freshly created one would not be sorted by tag.
        CHECK(created_fmt & CSR_CODE) << "A sorted CSR Matrix is required.";
        sampled_coo = aten::CSRRowWiseSamplingBiased(
          hg->GetCSRMatrix(etype), nodes, fanout, tag_offset, bias, replace);
        break;
      case CSC_CODE:
        CHECK(created_fmt & CSC_CODE) << "A sorted CSC Matrix is required.";
        sampled_coo = aten::CSRRowWiseSamplingBiased(
          hg->GetCSCMatrix(etype), nodes, fanout, tag_offset, bias, replace);
        // Rows of the CSC are destination nodes; transpose back to (src, dst) order.
        sampled_coo = aten::COOTranspose(sampled_coo);
        break;
      default:
        LOG(FATAL) << "Unsupported sparse format.";
    }
    subrel = UnitGraph::CreateFromCOO(
      hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo.num_rows, sampled_coo.num_cols,
      sampled_coo.row, sampled_coo.col);
    induced_edges = sampled_coo.data;
  }

  HeteroSubgraph ret;
  ret.graph = CreateHeteroGraph(hg->meta_graph(), {subrel}, hg->NumVerticesPerType());
  // All nodes are kept, so the per-type induced vertex arrays are left as default entries.
  ret.induced_vertices.resize(hg->NumVertexTypes());
  ret.induced_edges = {induced_edges};
  return ret;
}
DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighbors") DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighbors")
.set_body([] (DGLArgs args, DGLRetValue *rv) { .set_body([] (DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
...@@ -232,5 +311,26 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsTopk") ...@@ -232,5 +311,26 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsTopk")
*rv = HeteroGraphRef(subg); *rv = HeteroGraphRef(subg);
}); });
// FFI entry point of biased neighbor sampling.
// Arguments, in order: graph, seed nodes, fanout, bias, tag offsets,
// edge direction ("in" or "out") and the with-replacement flag.
DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsBiased")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
    HeteroGraphRef hg = args[0];
    const IdArray nodes = args[1];
    const int64_t fanout = args[2];
    const NDArray bias = args[3];
    const NDArray tag_offset = args[4];
    const std::string dir_str = args[5];
    const bool replace = args[6];

    CHECK(dir_str == "in" || dir_str == "out")
      << "Invalid edge direction. Must be \"in\" or \"out\".";
    // After the check above, anything that is not "out" is "in".
    const EdgeDir dir = (dir_str == "out") ? EdgeDir::kOut : EdgeDir::kIn;

    auto subg = std::make_shared<HeteroSubgraph>();
    *subg = sampling::SampleNeighborsBiased(
        hg.sptr(), nodes, fanout, bias, tag_offset, dir, replace);

    *rv = HeteroGraphRef(subg);
  });
} // namespace sampling } // namespace sampling
} // namespace dgl } // namespace dgl
...@@ -588,12 +588,115 @@ def test_sample_neighbors_with_0deg(): ...@@ -588,12 +588,115 @@ def test_sample_neighbors_with_0deg():
sg = dgl.sampling.sample_neighbors(g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir='out', replace=True) sg = dgl.sampling.sample_neighbors(g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir='out', replace=True)
assert sg.number_of_edges() == 0 assert sg.number_of_edges() == 0
def create_test_graph(num_nodes, num_edges_per_node, bipartite=False):
    """Build a random graph in which every source node has the same out-degree.

    Each of the ``num_nodes`` source nodes gets ``num_edges_per_node`` outgoing edges
    whose destinations are drawn without replacement from ``range(num_nodes)``, so a
    source node never has two edges to the same destination.

    Returns a bipartite heterograph with the single relation ``("u", "e", "v")`` when
    ``bipartite`` is True, and a homogeneous graph otherwise.
    """
    src_chunks = [np.array([nid] * num_edges_per_node) for nid in range(num_nodes)]
    dst_chunks = [np.random.choice(num_nodes, num_edges_per_node, replace=False)
                  for _ in range(num_nodes)]
    src = np.concatenate(src_chunks)
    dst = np.concatenate(dst_chunks)
    if not bipartite:
        return dgl.graph((src, dst))
    return dgl.heterograph({("u", "e", "v") : (src, dst)})
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented")
def test_sample_neighbors_biased_homogeneous():
    """Biased sampling on a homogeneous graph, for both edge directions and both
    with and without replacement."""
    g = create_test_graph(100, 30)

    def check_num(nodes, tag):
        # Count how many of the sampled neighbors carry each of the four tags.
        sampled_tags = F.asnumpy(tag)[F.asnumpy(nodes)]
        per_tag = [int((sampled_tags == t).sum()) for t in range(4)]
        # No tag 0
        assert per_tag[0] == 0
        # very rare tag 1
        assert per_tag[2] > 2 * per_tag[1]
        assert per_tag[3] > 2 * per_tag[1]

    tag = F.tensor(np.random.choice(4, 100))
    bias = F.tensor([0, 0.1, 10, 10], dtype=F.float32)

    def run_rounds(graph, edge_dir, replace):
        # The sampled neighbors are the sources of in-edges / the destinations of out-edges.
        neighbor_side = 0 if edge_dir == 'in' else 1
        for _ in range(5):
            subg = dgl.sampling.sample_neighbors_biased(
                graph, g.nodes(), 5, bias, edge_dir=edge_dir, replace=replace)
            check_num(subg.edges()[neighbor_side], tag)
            if not replace:
                # Without replacement no edge may be picked twice.
                u, v = subg.edges()
                edge_set = set(zip(F.asnumpy(u).tolist(), F.asnumpy(v).tolist()))
                assert len(edge_set) == subg.number_of_edges()

    # inedge / without replacement, then with replacement
    g_sorted = dgl.sort_in_edges(g, tag)
    run_rounds(g_sorted, 'in', False)
    run_rounds(g_sorted, 'in', True)

    # outedge / without replacement, then with replacement
    g_sorted = dgl.sort_out_edges(g, tag)
    run_rounds(g_sorted, 'out', False)
    run_rounds(g_sorted, 'out', True)
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented")
def test_sample_neighbors_biased_bipartite():
    """Biased sampling on a bipartite graph, for both edge directions and both
    with and without replacement."""
    g = create_test_graph(100, 30, True)
    num_dst = g.number_of_dst_nodes()
    bias = F.tensor([0, 0.01, 10, 10], dtype=F.float32)

    def check_num(nodes, tag):
        # Count how many of the sampled neighbors carry each of the four tags.
        sampled_tags = F.asnumpy(tag)[F.asnumpy(nodes)]
        per_tag = [int((sampled_tags == t).sum()) for t in range(4)]
        # No tag 0
        assert per_tag[0] == 0
        # very rare tag 1
        assert per_tag[2] > 2 * per_tag[1]
        assert per_tag[3] > 2 * per_tag[1]

    def run_rounds(graph, tag, edge_dir, replace):
        # Seeds are destination nodes for in-edges and source nodes for out-edges;
        # the sampled neighbors sit on the opposite side.
        neighbor_side = 0 if edge_dir == 'in' else 1
        for _ in range(5):
            seeds = g.dstnodes() if edge_dir == 'in' else g.srcnodes()
            subg = dgl.sampling.sample_neighbors_biased(
                graph, seeds, 5, bias, edge_dir=edge_dir, replace=replace)
            check_num(subg.edges()[neighbor_side], tag)
            if not replace:
                # Without replacement no edge may be picked twice.
                u, v = subg.edges()
                edge_set = set(zip(F.asnumpy(u).tolist(), F.asnumpy(v).tolist()))
                assert len(edge_set) == subg.number_of_edges()

    # inedge: tags are assigned to the 100 source nodes
    tag = F.tensor(np.random.choice(4, 100))
    g_sorted = dgl.sort_in_edges(g, tag)
    run_rounds(g_sorted, tag, 'in', False)
    run_rounds(g_sorted, tag, 'in', True)

    # outedge: tags are assigned to the destination nodes
    tag = F.tensor(np.random.choice(4, num_dst))
    g_sorted = dgl.sort_out_edges(g, tag)
    run_rounds(g_sorted, tag, 'out', False)
    run_rounds(g_sorted, tag, 'out', True)
# Run every sampling test when the file is executed directly.
# test_sample_neighbors() stays enabled: leaving it commented out would silently
# drop the coverage of the unbiased sampler from direct runs.
if __name__ == '__main__':
    test_random_walk()
    test_pack_traces()
    test_pinsage_sampling()
    test_sample_neighbors()
    test_sample_neighbors_outedge()
    test_sample_neighbors_topk()
    test_sample_neighbors_topk_outedge()
    test_sample_neighbors_with_0deg()
    test_sample_neighbors_biased_homogeneous()
    test_sample_neighbors_biased_bipartite()
...@@ -123,6 +123,9 @@ void _TestCSRSampling(bool has_data) { ...@@ -123,6 +123,9 @@ void _TestCSRSampling(bool has_data) {
} }
} }
TEST(RowwiseTest, TestCSRSampling) { TEST(RowwiseTest, TestCSRSampling) {
_TestCSRSampling<int32_t, float>(true); _TestCSRSampling<int32_t, float>(true);
_TestCSRSampling<int64_t, float>(true); _TestCSRSampling<int64_t, float>(true);
...@@ -356,3 +359,64 @@ TEST(RowwiseTest, TestCOOTopk) { ...@@ -356,3 +359,64 @@ TEST(RowwiseTest, TestCOOTopk) {
_TestCOOTopk<int32_t, double>(false); _TestCOOTopk<int32_t, double>(false);
_TestCOOTopk<int64_t, double>(false); _TestCOOTopk<int64_t, double>(false);
} }
// Checks biased row-wise sampling on the shared CSR fixture.
// With bias = {0, 0.5} entries of tag 0 must never be picked, so the outcome for the
// sampled rows is fully determined and can be asserted exactly.
template <typename Idx, typename FloatType>
void _TestCSRSamplingBiased(bool has_data) {
  typedef std::tuple<int, int, int> Edge;  // (row, col, data)

  auto mat = CSR<Idx>(has_data);
  // row - columns (presumably, judging from the expected edges below)
  // 0 - 0,1
  // 1 - 1
  // 3 - 2,3
  NDArray tag_offset = NDArray::FromVector(
    std::vector<Idx>({0, 1, 2,
                      0, 0, 1,
                      0, 0, 0,
                      0, 1, 2}));
  tag_offset = tag_offset.CreateView({4, 3}, tag_offset->dtype);
  IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 1, 3}));
  FloatArray bias = NDArray::FromVector(
    std::vector<FloatType>({0, 0.5}));

  // Edges that every sample must contain, and edges that must never show up.
  std::vector<Edge> expected;
  std::vector<Edge> forbidden;
  if (has_data) {
    expected.push_back(std::make_tuple(0, 1, 3));
    expected.push_back(std::make_tuple(1, 1, 0));
    expected.push_back(std::make_tuple(3, 3, 4));
    forbidden.push_back(std::make_tuple(0, 0, 2));
    forbidden.push_back(std::make_tuple(3, 2, 1));
  } else {
    expected.push_back(std::make_tuple(0, 1, 1));
    expected.push_back(std::make_tuple(1, 1, 2));
    expected.push_back(std::make_tuple(3, 3, 4));
    forbidden.push_back(std::make_tuple(0, 0, 0));
    forbidden.push_back(std::make_tuple(3, 2, 3));
  }

  // One sample per row, without replacement.
  for (int round = 0; round < 10; ++round) {
    auto rst = CSRRowWiseSamplingBiased(mat, rows, 1, tag_offset, bias, false);
    CheckSampledResult<Idx>(rst, rows, has_data);
    auto eset = ToEdgeSet<Idx>(rst);
    for (const Edge& e : expected) {
      ASSERT_TRUE(eset.count(e));
    }
  }

  // Three samples per row, with replacement.
  for (int round = 0; round < 10; ++round) {
    auto rst = CSRRowWiseSamplingBiased(mat, rows, 3, tag_offset, bias, true);
    CheckSampledResult<Idx>(rst, rows, has_data);
    auto eset = ToEdgeSet<Idx>(rst);
    for (const Edge& e : expected) {
      ASSERT_TRUE(eset.count(e));
    }
    for (const Edge& e : forbidden) {
      ASSERT_FALSE(eset.count(e));
    }
  }
}
// Runs the biased sampling check for every combination of index type (int32 / int64),
// bias type (float / double) and has_data flag (true / false).
TEST(RowwiseTest, TestCSRSamplingBiased) {
  // float bias
  _TestCSRSamplingBiased<int32_t, float>(true);
  _TestCSRSamplingBiased<int32_t, float>(false);
  _TestCSRSamplingBiased<int64_t, float>(true);
  _TestCSRSamplingBiased<int64_t, float>(false);
  // double bias
  _TestCSRSamplingBiased<int32_t, double>(true);
  _TestCSRSamplingBiased<int32_t, double>(false);
  _TestCSRSamplingBiased<int64_t, double>(true);
  _TestCSRSamplingBiased<int64_t, double>(false);
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment