Commit fb4246e5 authored by GaiYu0's avatar GaiYu0 Committed by Da Zheng
Browse files

[Feature]Uniform layer-wise sampler (#416)

* migrate to node-flow

* uniform layer sampler test cases

* more test cases

* documentations

* fix lint errors

* fix lint errors

* fix lint errors

* iota

* add asnumpy

* requested changes

* fix indptr error

* fix lint errors

* requested changes & fix lint errors

* fix lint errors

* fix LayerSampler unit test
parent a88f3511
...@@ -18,6 +18,7 @@ typedef dgl::runtime::NDArray IdArray; ...@@ -18,6 +18,7 @@ typedef dgl::runtime::NDArray IdArray;
typedef dgl::runtime::NDArray DegreeArray; typedef dgl::runtime::NDArray DegreeArray;
typedef dgl::runtime::NDArray BoolArray; typedef dgl::runtime::NDArray BoolArray;
typedef dgl::runtime::NDArray IntArray; typedef dgl::runtime::NDArray IntArray;
typedef dgl::runtime::NDArray FloatArray;
struct Subgraph; struct Subgraph;
struct NodeFlow; struct NodeFlow;
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#ifndef DGL_SAMPLER_H_ #ifndef DGL_SAMPLER_H_
#define DGL_SAMPLER_H_ #define DGL_SAMPLER_H_
#include <vector>
#include <string> #include <string>
#include "graph_interface.h" #include "graph_interface.h"
...@@ -61,6 +62,20 @@ class SamplerOp { ...@@ -61,6 +62,20 @@ class SamplerOp {
int num_hops, int expand_factor, int num_hops, int expand_factor,
const bool add_self_loop); const bool add_self_loop);
/*!
 * \brief Sample a graph from the seed vertices with layer-wise sampling.
 * Each layer is sampled with a uniform distribution (with replacement,
 * duplicates merged).
 *
 * \param graph The (read-only) graph to sample from.
 * \param seed_array The nodes where sampling starts (the last NodeFlow layer).
 * \param neigh_type The type of edges to traverse ("in" or "out").
 * \param layer_sizes The requested number of sampled nodes per layer.
 * \return a NodeFlow graph.
 */
static NodeFlow LayerUniformSample(const ImmutableGraph *graph, IdArray seed_array,
                                   const std::string &neigh_type,
                                   const std::vector<size_t> &layer_sizes);
/*! /*!
* \brief Batch-generate random walk traces * \brief Batch-generate random walk traces
* \param seeds The array of starting vertex IDs * \param seeds The array of starting vertex IDs
......
from .sampler import NeighborSampler from .sampler import NeighborSampler, LayerSampler
from .randomwalk import * from .randomwalk import *
...@@ -14,18 +14,25 @@ try: ...@@ -14,18 +14,25 @@ try:
except ImportError: except ImportError:
import queue import queue
__all__ = ['NeighborSampler'] __all__ = ['NeighborSampler', 'LayerSampler']
class NSSubgraphLoader(object): class SampledSubgraphLoader(object):
def __init__(self, g, batch_size, expand_factor, num_hops=1, def __init__(self, g, batch_size, sampler,
expand_factor=None, num_hops=1, layer_sizes=None,
neighbor_type='in', node_prob=None, seed_nodes=None, neighbor_type='in', node_prob=None, seed_nodes=None,
shuffle=False, num_workers=1, add_self_loop=False): shuffle=False, num_workers=1, add_self_loop=False):
self._g = g self._g = g
if not g._graph.is_readonly(): if not g._graph.is_readonly():
raise NotImplementedError("NodeFlow loader only support read-only graphs.") raise NotImplementedError("NodeFlow loader only support read-only graphs.")
self._batch_size = batch_size self._batch_size = batch_size
self._sampler = sampler
if sampler == 'neighbor':
self._expand_factor = expand_factor self._expand_factor = expand_factor
self._num_hops = num_hops self._num_hops = num_hops
elif sampler == 'layer':
self._layer_sizes = layer_sizes
else:
raise NotImplementedError()
self._node_prob = node_prob self._node_prob = node_prob
self._add_self_loop = add_self_loop self._add_self_loop = add_self_loop
if self._node_prob is not None: if self._node_prob is not None:
...@@ -54,9 +61,13 @@ class NSSubgraphLoader(object): ...@@ -54,9 +61,13 @@ class NSSubgraphLoader(object):
end = min((self._nflow_idx + 1) * self._batch_size, num_nodes) end = min((self._nflow_idx + 1) * self._batch_size, num_nodes)
seed_ids.append(utils.toindex(self._seed_nodes[start:end])) seed_ids.append(utils.toindex(self._seed_nodes[start:end]))
self._nflow_idx += 1 self._nflow_idx += 1
if self._sampler == 'neighbor':
sgi = self._g._graph.neighbor_sampling(seed_ids, self._expand_factor, sgi = self._g._graph.neighbor_sampling(seed_ids, self._expand_factor,
self._num_hops, self._neighbor_type, self._num_hops, self._neighbor_type,
self._node_prob, self._add_self_loop) self._node_prob, self._add_self_loop)
elif self._sampler == 'layer':
sgi = self._g._graph.layer_sampling(seed_ids, self._layer_sizes,
self._neighbor_type, self._node_prob)
nflows = [NodeFlow(self._g, i) for i in sgi] nflows = [NodeFlow(self._g, i) for i in sgi]
self._nflows.extend(nflows) self._nflows.extend(nflows)
...@@ -264,8 +275,51 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1, ...@@ -264,8 +275,51 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
generator generator
The generator of NodeFlows. The generator of NodeFlows.
''' '''
loader = NSSubgraphLoader(g, batch_size, expand_factor, num_hops, neighbor_type, node_prob, loader = SampledSubgraphLoader(g, batch_size, 'neighbor',
seed_nodes, shuffle, num_workers, add_self_loop) expand_factor=expand_factor, num_hops=num_hops,
neighbor_type=neighbor_type, node_prob=node_prob,
seed_nodes=seed_nodes, shuffle=shuffle,
num_workers=num_workers)
if not prefetch:
return loader
else:
return _PrefetchingLoader(loader, num_prefetch=num_workers*2)
def LayerSampler(g, batch_size, layer_sizes,
                 neighbor_type='in', node_prob=None, seed_nodes=None,
                 shuffle=False, num_workers=1, prefetch=False):
    '''Create a sampler that performs uniform layer-wise sampling.

    This creates a NodeFlow loader that samples subgraphs from the input graph
    with layer-wise sampling. This sampling method is implemented in C and can
    perform sampling very efficiently.

    The NodeFlow loader returns a list of NodeFlows.
    The size of the NodeFlow list is the number of workers.

    Parameters
    ----------
    g: the DGLGraph where we sample NodeFlows.
    batch_size: the number of NodeFlows in a batch.
    layer_sizes: a list with the number of nodes to sample for each layer.
    neighbor_type: indicates which edges to traverse when sampling
        ("in" for in-edges, "out" for out-edges).
    node_prob: the probability that a neighbor node is sampled.
        Not implemented.
    seed_nodes: a list of nodes where we sample NodeFlows from.
        If it's None, the seed vertices are all vertices in the graph.
    shuffle: indicates the sampled NodeFlows are shuffled.
    num_workers: the number of worker threads that sample NodeFlows in parallel.
    prefetch : bool, default False
        Whether to prefetch the samples in the next batch.

    Returns
    -------
    A NodeFlow iterator
        The iterator returns a list of batched NodeFlows.
    '''
    loader = SampledSubgraphLoader(g, batch_size, 'layer', layer_sizes=layer_sizes,
                                   neighbor_type=neighbor_type, node_prob=node_prob,
                                   seed_nodes=seed_nodes, shuffle=shuffle,
                                   num_workers=num_workers)
if not prefetch: if not prefetch:
return loader return loader
else: else:
......
...@@ -10,6 +10,7 @@ from ._ffi.base import c_array ...@@ -10,6 +10,7 @@ from ._ffi.base import c_array
from ._ffi.function import _init_api from ._ffi.function import _init_api
from .base import DGLError from .base import DGLError
from . import backend as F from . import backend as F
from . import ndarray as nd
from . import utils from . import utils
GraphIndexHandle = ctypes.c_void_p GraphIndexHandle = ctypes.c_void_p
...@@ -694,6 +695,25 @@ class GraphIndex(object): ...@@ -694,6 +695,25 @@ class GraphIndex(object):
utils.toindex(rst(num_subgs * 3 + i)), utils.toindex(rst(num_subgs * 3 + i)),
utils.toindex(rst(num_subgs * 4 + i))) for i in range(num_subgs)] utils.toindex(rst(num_subgs * 4 + i))) for i in range(num_subgs)]
def layer_sampling(self, seed_ids, layer_sizes, neighbor_type, node_prob=None):
    """Perform uniform layer-wise sampling and return one NodeFlowIndex
    per seed batch.

    seed_ids is a list of Index objects (one per batch); layer_sizes is a
    list of per-layer sample counts.  node_prob (non-uniform sampling) is
    not implemented yet.
    """
    if len(seed_ids) == 0:
        return []
    # Convert the seed batches and the layer sizes into DGL tensors for the FFI.
    seed_ids = [v.todgltensor() for v in seed_ids]
    layer_sizes = nd.from_dlpack(F.zerocopy_to_dlpack(F.tensor(layer_sizes)))
    if node_prob is not None:
        raise NotImplementedError()
    rst = _layer_uniform_sampling(self, seed_ids, neighbor_type, layer_sizes)
    # The C API returns 5 arrays per subgraph, grouped field-by-field:
    # graph handle, node mapping, edge mapping, layer offsets, flow offsets.
    num_subgs = len(seed_ids)
    nflows = []
    for i in range(num_subgs):
        nflows.append(NodeFlowIndex(rst(i), self,
                                    utils.toindex(rst(num_subgs + i)),
                                    utils.toindex(rst(num_subgs * 2 + i)),
                                    utils.toindex(rst(num_subgs * 3 + i)),
                                    utils.toindex(rst(num_subgs * 4 + i))))
    return nflows
def random_walk(self, seeds, num_traces, num_hops): def random_walk(self, seeds, num_traces, num_hops):
"""Random walk sampling. """Random walk sampling.
...@@ -1151,3 +1171,25 @@ def _uniform_sampling(gidx, seed_ids, neigh_type, num_hops, expand_factor, add_s ...@@ -1151,3 +1171,25 @@ def _uniform_sampling(gidx, seed_ids, neigh_type, num_hops, expand_factor, add_s
return _NEIGHBOR_SAMPLING_APIS[len(seed_ids)](gidx._handle, *seed_ids, neigh_type, return _NEIGHBOR_SAMPLING_APIS[len(seed_ids)](gidx._handle, *seed_ids, neigh_type,
num_hops, expand_factor, num_seeds, num_hops, expand_factor, num_seeds,
add_self_loop) add_self_loop)
# Dispatch table mapping the (power-of-two, padded) number of seed batches to
# the corresponding C API specialization.  Callers with other batch counts are
# padded with empty seed arrays up to the next power of two.
_LAYER_SAMPLING_APIS = {
    1: _CAPI_DGLGraphLayerUniformSampling,
    2: _CAPI_DGLGraphLayerUniformSampling2,
    4: _CAPI_DGLGraphLayerUniformSampling4,
    8: _CAPI_DGLGraphLayerUniformSampling8,
    16: _CAPI_DGLGraphLayerUniformSampling16,
    32: _CAPI_DGLGraphLayerUniformSampling32,
    64: _CAPI_DGLGraphLayerUniformSampling64,
    128: _CAPI_DGLGraphLayerUniformSampling128,
}
def _layer_uniform_sampling(gidx, seed_ids, neigh_type, layer_sizes):
    """Dispatch uniform layer-wise sampling to the C API.

    The C APIs are only specialized for power-of-two batch counts, so the
    seed list is padded in place with empty arrays up to the next power of
    two; ``num_seeds`` records how many batches are real.

    Parameters
    ----------
    gidx : GraphIndex wrapping the C graph handle.
    seed_ids : list of DGL tensors of seed node ids (mutated by padding).
    neigh_type : "in" or "out" -- which edges to traverse.
    layer_sizes : DGL tensor with the requested size of each layer.

    Returns
    -------
    The packed-function result holding 5 arrays per subgraph.
    """
    num_seeds = len(seed_ids)
    empty_ids = []
    if len(seed_ids) > 1 and len(seed_ids) not in _LAYER_SAMPLING_APIS.keys():
        # Bug fix: the original computed the padding from an undefined name
        # (`dgl_ids`), raising NameError for non-power-of-two batch counts.
        remain = 2**int(math.ceil(math.log2(len(seed_ids)))) - len(seed_ids)
        empty_ids = _EMPTY_ARRAYS[0:remain]
        seed_ids.extend([empty.todgltensor() for empty in empty_ids])
    assert len(seed_ids) in _LAYER_SAMPLING_APIS.keys()
    return _LAYER_SAMPLING_APIS[len(seed_ids)](gidx._handle, *seed_ids, neigh_type,
                                               layer_sizes, num_seeds)
...@@ -474,6 +474,48 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphUniformSampling64") ...@@ -474,6 +474,48 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphUniformSampling64")
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphUniformSampling128") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphUniformSampling128")
.set_body(CAPI_NeighborUniformSample<128>); .set_body(CAPI_NeighborUniformSample<128>);
// Packed-func entry point for layer-wise uniform sampling.  One template
// instantiation exists per (power-of-two) number of seed arrays per call.
// Argument layout: [graph handle, seed arrays x num_seeds, neighbor type,
// layer-size tensor, number of valid (non-padding) seed arrays].
template<int num_seeds>
void CAPI_LayerUniformSample(DGLArgs args, DGLRetValue* rv) {
  GraphHandle ghandle = args[0];
  std::vector<IdArray> seeds(num_seeds);
  for (size_t i = 0; i < seeds.size(); i++)
    seeds[i] = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[i + 1]));
  std::string neigh_type = args[num_seeds + 1];
  // The per-layer sample counts arrive as a 1-D tensor; copy into a vector.
  auto ls_array = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[num_seeds + 2]));
  size_t *ls_data = static_cast<size_t*>(ls_array->data);
  size_t ls_len = ls_array->shape[0];
  std::vector<size_t> layer_sizes;
  std::copy(ls_data, ls_data + ls_len, std::back_inserter(layer_sizes));
  // The Python side pads the seed list up to num_seeds; only the first
  // num_valid_seeds entries are real batches and get sampled.
  const int num_valid_seeds = args[num_seeds + 3];
  const GraphInterface *ptr = static_cast<const GraphInterface *>(ghandle);
  const ImmutableGraph *gptr = dynamic_cast<const ImmutableGraph*>(ptr);
  CHECK(gptr) << "sampling isn't implemented in mutable graph";
  CHECK(num_valid_seeds <= num_seeds);
  std::vector<NodeFlow> subgs(seeds.size());
  // Each seed batch is sampled independently, so parallelize across batches.
#pragma omp parallel for
  for (int i = 0; i < num_valid_seeds; i++) {
    subgs[i] = SamplerOp::LayerUniformSample(gptr, seeds[i], neigh_type, layer_sizes);
  }
  *rv = ConvertSubgraphToPackedFunc(subgs);
}
// Register one global entry point per supported batch count (powers of two),
// mirroring the registration scheme of _CAPI_DGLGraphUniformSampling*.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphLayerUniformSampling")
.set_body(CAPI_LayerUniformSample<1>);

DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphLayerUniformSampling2")
.set_body(CAPI_LayerUniformSample<2>);

DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphLayerUniformSampling4")
.set_body(CAPI_LayerUniformSample<4>);

DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphLayerUniformSampling8")
.set_body(CAPI_LayerUniformSample<8>);

DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphLayerUniformSampling16")
.set_body(CAPI_LayerUniformSample<16>);

DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphLayerUniformSampling32")
.set_body(CAPI_LayerUniformSample<32>);

DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphLayerUniformSampling64")
.set_body(CAPI_LayerUniformSample<64>);

DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphLayerUniformSampling128")
.set_body(CAPI_LayerUniformSample<128>);
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphGetAdj") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphGetAdj")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <algorithm> #include <algorithm>
#include <cstdlib> #include <cstdlib>
#include <cmath> #include <cmath>
#include <numeric>
#ifdef _MSC_VER #ifdef _MSC_VER
// rand in MS compiler works well in multi-threading. // rand in MS compiler works well in multi-threading.
...@@ -531,4 +532,187 @@ IdArray SamplerOp::RandomWalk( ...@@ -531,4 +532,187 @@ IdArray SamplerOp::RandomWalk(
return traces; return traces;
} }
namespace {
void ConstructLayers(const int64_t *indptr,
                     const dgl_id_t *indices,
                     const IdArray seed_array,
                     const std::vector<size_t> &layer_sizes,
                     std::vector<dgl_id_t> *layer_offsets,
                     std::vector<dgl_id_t> *node_mapping,
                     std::vector<int64_t> *actl_layer_sizes,
                     std::vector<float> *probabilities) {
  /*
   * Given a graph (CSR indptr/indices) and a collection of seed nodes, this
   * function constructs NodeFlow layers via uniform layer-wise sampling and
   * returns the resultant layers (node_mapping laid out layer by layer, plus
   * per-layer sizes/offsets) and their importance-sampling probabilities.
   */
  const dgl_id_t* seed_data = static_cast<dgl_id_t*>(seed_array->data);
  size_t seed_len = seed_array->shape[0];
  // The seed nodes form the top layer with probability 1.
  std::copy(seed_data, seed_data + seed_len, std::back_inserter(*node_mapping));
  actl_layer_sizes->push_back(node_mapping->size());
  probabilities->insert(probabilities->end(), node_mapping->size(), 1);
  size_t curr = 0;
  size_t next = node_mapping->size();
  unsigned int rand_seed = time(nullptr);
  // Layers are generated from the seeds downward, so walk sizes in reverse.
  for (auto i = layer_sizes.rbegin(); i != layer_sizes.rend(); ++i) {
    const auto layer_size = *i;
    // Candidate pool: union of all neighbors of the previous layer.
    std::unordered_set<dgl_id_t> candidate_set;
    for (auto j = curr; j != next; ++j) {
      auto src = (*node_mapping)[j];
      candidate_set.insert(indices + indptr[src], indices + indptr[src + 1]);
    }
    std::vector<dgl_id_t> candidate_vector;
    std::copy(candidate_set.begin(), candidate_set.end(),
              std::back_inserter(candidate_vector));
    // Sample layer_size candidates with replacement; merge duplicates and
    // record how often each node was drawn.
    std::unordered_map<dgl_id_t, size_t> n_occurrences;
    auto n_candidates = candidate_vector.size();
    // Fix: guard against an empty candidate pool (frontier nodes with no
    // neighbors) -- the unguarded `% n_candidates` was undefined behavior
    // when n_candidates == 0.  An empty pool now yields an empty layer.
    for (size_t j = 0; n_candidates != 0 && j != layer_size; ++j) {
      auto dst = candidate_vector[rand_r(&rand_seed) % n_candidates];
      if (!n_occurrences.insert(std::make_pair(dst, 1)).second) {
        ++n_occurrences[dst];
      }
    }
    for (auto const &pair : n_occurrences) {
      node_mapping->push_back(pair.first);
      float p = pair.second * n_candidates / static_cast<float>(layer_size);
      probabilities->push_back(p);
    }
    actl_layer_sizes->push_back(node_mapping->size() - next);
    curr = next;
    next = node_mapping->size();
  }
  // Layers were built seeds-first; flip to bottom-up (input layer first).
  std::reverse(node_mapping->begin(), node_mapping->end());
  std::reverse(actl_layer_sizes->begin(), actl_layer_sizes->end());
  layer_offsets->push_back(0);
  for (const auto &size : *actl_layer_sizes) {
    layer_offsets->push_back(size + layer_offsets->back());
  }
}
void ConstructFlows(const int64_t *indptr,
                    const dgl_id_t *indices,
                    const dgl_id_t *eids,
                    const std::vector<dgl_id_t> &node_mapping,
                    const std::vector<int64_t> &actl_layer_sizes,
                    std::vector<int64_t> *sub_indptr,
                    std::vector<dgl_id_t> *sub_indices,
                    std::vector<dgl_id_t> *sub_eids,
                    std::vector<dgl_id_t> *flow_offsets,
                    std::vector<dgl_id_t> *edge_mapping) {
  /*
   * Given a graph and a sequence of NodeFlow layers, this function constructs
   * dense subgraphs (flows) between consecutive layers.
   *
   * indptr/indices/eids describe the parent graph's CSR; node_mapping maps
   * NodeFlow node id -> parent node id, laid out layer by layer with sizes
   * actl_layer_sizes.  Outputs are the NodeFlow's CSR (sub_indptr,
   * sub_indices, sub_eids), per-flow edge offsets, and the mapping from
   * NodeFlow edge id -> parent edge id.
   */
  auto n_flows = actl_layer_sizes.size() - 1;
  // Nodes of the bottom (input) layer have no in-edges inside the NodeFlow.
  sub_indptr->insert(sub_indptr->end(), actl_layer_sizes.front() + 1, 0);
  flow_offsets->push_back(0);
  int64_t first = 0;  // NodeFlow id of the first node in the current source layer
  for (size_t i = 0; i < n_flows; ++i) {
    auto src_size = actl_layer_sizes[i];
    // Map parent node id -> NodeFlow node id for the source layer.
    std::unordered_map<dgl_id_t, dgl_id_t> source_map;
    for (int64_t j = 0; j < src_size; ++j) {
      source_map.insert(std::make_pair(node_mapping[first + j], first + j));
    }
    auto dst_size = actl_layer_sizes[i + 1];
    for (int64_t j = 0; j < dst_size; ++j) {
      auto dst = node_mapping[first + src_size + j];
      typedef std::pair<dgl_id_t, dgl_id_t> id_pair;
      std::vector<id_pair> neighbor_indices;
      // Keep only the parent edges whose source endpoint was sampled into
      // the previous layer.
      for (int64_t k = indptr[dst]; k < indptr[dst + 1]; ++k) {
        // TODO(gaiyu): accelerate hash table lookup
        auto ret = source_map.find(indices[k]);
        if (ret != source_map.end()) {
          neighbor_indices.push_back(std::make_pair(ret->second, eids[k]));
        }
      }
      // Sort each CSR row by NodeFlow source id so rows are ordered.
      auto cmp = [](const id_pair p, const id_pair q)->bool { return p.first < q.first; };
      std::sort(neighbor_indices.begin(), neighbor_indices.end(), cmp);
      for (const auto &pair : neighbor_indices) {
        sub_indices->push_back(pair.first);
        edge_mapping->push_back(pair.second);
      }
      sub_indptr->push_back(sub_indices->size());
    }
    flow_offsets->push_back(sub_indices->size());
    first += src_size;
  }
  // NodeFlow edge ids are consecutive integers over the collected edges.
  sub_eids->resize(sub_indices->size());
  std::iota(sub_eids->begin(), sub_eids->end(), 0);
}
} // namespace
NodeFlow SamplerOp::LayerUniformSample(const ImmutableGraph *graph,
                                       const IdArray seed_array,
                                       const std::string &neighbor_type,
                                       const std::vector<size_t> &layer_sizes) {
  // Sample one NodeFlow from `graph` starting at `seed_array` using uniform
  // layer-wise sampling: first build the layers, then the flows (dense
  // bipartite subgraphs) between consecutive layers.
  // Traverse in-edges or out-edges depending on the requested neighbor type.
  const auto g_csr = neighbor_type == "in" ? graph->GetInCSR() : graph->GetOutCSR();
  const int64_t *indptr = g_csr->indptr.data();
  const dgl_id_t *indices = g_csr->indices.data();
  const dgl_id_t *eids = g_csr->edge_ids.data();
  std::vector<dgl_id_t> layer_offsets;
  std::vector<dgl_id_t> node_mapping;
  std::vector<int64_t> actl_layer_sizes;
  std::vector<float> probabilities;
  ConstructLayers(indptr,
                  indices,
                  seed_array,
                  layer_sizes,
                  &layer_offsets,
                  &node_mapping,
                  &actl_layer_sizes,
                  &probabilities);
  NodeFlow nf;
  int64_t n_nodes = node_mapping.size();
  // TODO(gaiyu): a better estimate for the expected number of nodes
  auto sub_csr = std::make_shared<ImmutableGraph::CSR>(n_nodes, n_nodes);
  // ConstructFlows appends to indptr from scratch, so drop the pre-sized rows.
  sub_csr->indptr.clear();  // TODO(zhengda): Why indptr.resize(num_vertices + 1)?
  std::vector<dgl_id_t> flow_offsets;
  std::vector<dgl_id_t> edge_mapping;
  ConstructFlows(indptr,
                 indices,
                 eids,
                 node_mapping,
                 actl_layer_sizes,
                 &(sub_csr->indptr),
                 &(sub_csr->indices),
                 &(sub_csr->edge_ids),
                 &flow_offsets,
                 &edge_mapping);
  // The sampled CSR stores the traversed direction; install it on the
  // matching side of the new immutable graph.
  if (neighbor_type == "in") {
    nf.graph = GraphPtr(new ImmutableGraph(sub_csr, nullptr, graph->IsMultigraph()));
  } else {
    nf.graph = GraphPtr(new ImmutableGraph(nullptr, sub_csr, graph->IsMultigraph()));
  }
  // Copy the bookkeeping vectors into int64 NDArrays for the NodeFlow.
  nf.node_mapping = IdArray::Empty({n_nodes},
                                   DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  nf.edge_mapping = IdArray::Empty({static_cast<int64_t>(edge_mapping.size())},
                                   DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  nf.layer_offsets = IdArray::Empty({static_cast<int64_t>(layer_offsets.size())},
                                    DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  nf.flow_offsets = IdArray::Empty({static_cast<int64_t>(flow_offsets.size())},
                                   DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  std::copy(node_mapping.begin(), node_mapping.end(),
            static_cast<dgl_id_t*>(nf.node_mapping->data));
  std::copy(edge_mapping.begin(), edge_mapping.end(),
            static_cast<dgl_id_t*>(nf.edge_mapping->data));
  std::copy(layer_offsets.begin(), layer_offsets.end(),
            static_cast<dgl_id_t*>(nf.layer_offsets->data));
  std::copy(flow_offsets.begin(), flow_offsets.end(),
            static_cast<dgl_id_t*>(nf.flow_offsets->data));
  return nf;
}
} // namespace dgl } // namespace dgl
...@@ -71,17 +71,21 @@ def check_basic(g, nf): ...@@ -71,17 +71,21 @@ def check_basic(g, nf):
def test_basic(): def test_basic():
num_layers = 2 num_layers = 2
g = generate_rand_graph(100, connect_more=True) g = generate_rand_graph(100, connect_more=True)
print(0, 0)
nf = create_full_node_flow(g, num_layers) nf = create_full_node_flow(g, num_layers)
print(0, 1)
assert nf.number_of_nodes() == g.number_of_nodes() * (num_layers + 1) assert nf.number_of_nodes() == g.number_of_nodes() * (num_layers + 1)
assert nf.number_of_edges() == g.number_of_edges() * num_layers assert nf.number_of_edges() == g.number_of_edges() * num_layers
assert nf.num_layers == num_layers + 1 assert nf.num_layers == num_layers + 1
assert nf.layer_size(0) == g.number_of_nodes() assert nf.layer_size(0) == g.number_of_nodes()
assert nf.layer_size(1) == g.number_of_nodes() assert nf.layer_size(1) == g.number_of_nodes()
check_basic(g, nf) check_basic(g, nf)
print(0, 2)
parent_nids = F.arange(0, g.number_of_nodes()) parent_nids = F.arange(0, g.number_of_nodes())
nids = dgl.graph_index.map_to_nodeflow_nid(nf._graph, 0, nids = dgl.graph_index.map_to_nodeflow_nid(nf._graph, 0,
utils.toindex(parent_nids)).tousertensor() utils.toindex(parent_nids)).tousertensor()
print(0, 3)
assert F.array_equal(nids, parent_nids) assert F.array_equal(nids, parent_nids)
g = generate_rand_graph(100) g = generate_rand_graph(100)
......
...@@ -98,6 +98,51 @@ def test_10neighbor_sampler(): ...@@ -98,6 +98,51 @@ def test_10neighbor_sampler():
check_10neighbor_sampler(g, seeds=np.unique(np.random.randint(0, g.number_of_nodes(), check_10neighbor_sampler(g, seeds=np.unique(np.random.randint(0, g.number_of_nodes(),
size=int(g.number_of_nodes() / 10)))) size=int(g.number_of_nodes() / 10))))
def test_layer_sampler(prefetch=False):
    """End-to-end check of LayerSampler: layer sizes, node/edge mappings,
    seed-batch recovery, and per-block endpoint consistency."""
    g = generate_rand_graph(100)
    nid = g.nodes()
    src, dst, eid = g.all_edges(form='all', order='eid')
    n_batches = 5
    batch_size = 50
    seed_batches = [np.sort(np.random.choice(F.asnumpy(nid), batch_size, replace=False))
                    for i in range(n_batches)]
    seed_nodes = np.hstack(seed_batches)
    layer_sizes = [50] * 3
    LayerSampler = getattr(dgl.contrib.sampling, 'LayerSampler')
    sampler = LayerSampler(g, batch_size, layer_sizes, 'in',
                           seed_nodes=seed_nodes, num_workers=4, prefetch=prefetch)
    for sub_g in sampler:
        # Sampling with replacement then deduplicating yields fewer distinct
        # nodes than requested (with overwhelming probability).
        assert all(sub_g.layer_size(i) < size for i, size in enumerate(layer_sizes))
        sub_nid = F.arange(0, sub_g.number_of_nodes())
        assert all(np.all(np.isin(F.asnumpy(sub_g.layer_nid(i)), F.asnumpy(sub_nid)))
                   for i in range(sub_g.num_layers))
        assert np.all(np.isin(F.asnumpy(sub_g.map_to_parent_nid(sub_nid)),
                              F.asnumpy(nid)))

        sub_eid = F.arange(0, sub_g.number_of_edges())
        assert np.all(np.isin(F.asnumpy(sub_g.map_to_parent_eid(sub_eid)),
                              F.asnumpy(eid)))

        # The top layer must exactly match one of the seed batches.
        assert any(np.all(np.sort(F.asnumpy(sub_g.layer_parent_nid(-1))) == seed_batch)
                   for seed_batch in seed_batches)

        sub_src, sub_dst = sub_g.all_edges(order='eid')
        for i in range(sub_g.num_blocks):
            block_eid = sub_g.block_eid(i)
            block_src = sub_g.map_to_parent_nid(sub_src[block_eid])
            block_dst = sub_g.map_to_parent_nid(sub_dst[block_eid])

            block_parent_eid = sub_g.block_parent_eid(i)
            block_parent_src = src[block_parent_eid]
            block_parent_dst = dst[block_parent_eid]

            assert np.all(F.asnumpy(block_src == block_parent_src))
            # Fix: block_dst/block_parent_dst were computed but never
            # checked -- the destination endpoints went unverified.
            assert np.all(F.asnumpy(block_dst == block_parent_dst))

        n_layers = sub_g.num_layers
        sub_n = sub_g.number_of_nodes()
        assert sum(F.shape(sub_g.layer_nid(i))[0] for i in range(n_layers)) == sub_n

        n_blocks = sub_g.num_blocks
        sub_m = sub_g.number_of_edges()
        assert sum(F.shape(sub_g.block_eid(i))[0] for i in range(n_blocks)) == sub_m
def test_random_walk(): def test_random_walk():
edge_list = [(0, 1), (1, 2), (2, 3), (3, 4), edge_list = [(0, 1), (1, 2), (2, 3), (3, 4),
(4, 3), (3, 2), (2, 1), (1, 0)] (4, 3), (3, 2), (2, 1), (1, 0)]
...@@ -123,4 +168,6 @@ if __name__ == '__main__': ...@@ -123,4 +168,6 @@ if __name__ == '__main__':
test_10neighbor_sampler_all() test_10neighbor_sampler_all()
test_1neighbor_sampler() test_1neighbor_sampler()
test_10neighbor_sampler() test_10neighbor_sampler()
test_layer_sampler()
test_layer_sampler(prefetch=True)
test_random_walk() test_random_walk()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment