Unverified Commit 66971c1a authored by Da Zheng's avatar Da Zheng Committed by GitHub

[Feature] Add edge sampling for link prediction (#780)

* add edge sampler.

* add test and run.

* add negative sampling.

* remap the edge subgraph vid.

* negative graph excludes edges of positive edges.

* remove print.

* avoid sampling NodeFlow when expand_factor or num_hops is 0.

* fix a bug when excluding nodes in negative graph.

* support multigraph.

* exclude positive edges.

* fix memory leak.

* return subgraph object directly.

* fix many problems.

* add comments.

* address comments
parent 2bff8339
from .sampler import NeighborSampler, LayerSampler
from .sampler import NeighborSampler, LayerSampler, EdgeSampler
from .randomwalk import *
from .dis_sampler import SamplerSender, SamplerReceiver
from .dis_sampler import SamplerPool
@@ -10,32 +10,33 @@ from ..._ffi.function import _init_api
from ... import utils
from ...nodeflow import NodeFlow
from ... import backend as F
from ... import subgraph
try:
import Queue as queue
except ImportError:
import queue
__all__ = ['NeighborSampler', 'LayerSampler']
__all__ = ['NeighborSampler', 'LayerSampler', 'EdgeSampler']
class NodeFlowSamplerIter(object):
class SamplerIter(object):
def __init__(self, sampler):
super(NodeFlowSamplerIter, self).__init__()
super(SamplerIter, self).__init__()
self._sampler = sampler
self._nflows = []
self._nflow_idx = 0
self._batches = []
self._batch_idx = 0
def prefetch(self):
nflows = self._sampler.fetch(self._nflow_idx)
self._nflows.extend(nflows)
self._nflow_idx += len(nflows)
batches = self._sampler.fetch(self._batch_idx)
self._batches.extend(batches)
self._batch_idx += len(batches)
def __next__(self):
if len(self._nflows) == 0:
if len(self._batches) == 0:
self.prefetch()
if len(self._nflows) == 0:
if len(self._batches) == 0:
raise StopIteration
return self._nflows.pop(0)
return self._batches.pop(0)
class PrefetchingWrapper(object):
"""Internal shared prefetcher logic. It can be sub-classed by a Thread-based implementation
@@ -186,7 +187,7 @@ class NodeFlowSampler(object):
raise NotImplementedError
def __iter__(self):
it = NodeFlowSamplerIter(self)
it = SamplerIter(self)
if self._num_prefetch:
return self._prefetching_wrapper_class(it, self._num_prefetch)
else:
@@ -434,6 +435,152 @@ class LayerSampler(NodeFlowSampler):
nflows = [NodeFlow(self.g, obj) for obj in nfobjs]
return nflows
class EdgeSampler(object):
'''Edge sampler for link prediction.
This sampler draws edges from a given graph. The edges sampled for a batch are
placed in a subgraph before being returned. In many link prediction tasks,
negative edges are required to train a model. A negative edge is constructed by
corrupting an existing edge in the graph. The current implementation
supports two ways of corrupting an edge: corrupt the head node of
an edge (by randomly selecting a node as the head node), or corrupt
the tail node of an edge. When we corrupt the head node of an edge, we randomly
sample a node from the entire graph as the new head node. The constructed
edge may already exist in the graph. By default, the implementation doesn't
explicitly check whether a sampled negative edge exists in the graph. However,
a user can exclude positive edges from the negative edges by specifying
'exclude_positive=True'. When negative edges are created, a batch of negative
edges is also placed in a subgraph.
Currently, negative_mode supports only 'head' and 'tail'.
If negative_mode=='head', the negative edges are generated by corrupting
head nodes; otherwise, the tail nodes are corrupted.
Parameters
----------
g : DGLGraph
The DGLGraph where we sample edges.
batch_size : int
The batch size (i.e., the number of positive edges sampled from the graph per batch).
seed_edges : tensor
A list of edge IDs to sample from.
shuffle : bool
Whether to randomly shuffle the list of edges to sample from.
num_workers : int
The number of workers to sample edges in parallel.
prefetch : bool, optional
If True, prefetch the samples of the next batch. Default: False
negative_mode : string
The method used to construct negative edges. Possible values are 'head', 'tail'.
neg_sample_size : int
The number of negative edges to sample for each edge.
exclude_positive : bool
Whether to exclude positive edges from the negative edges.
Class properties
----------------
immutable_only : bool
Whether the sampler only works on immutable graphs.
Subclasses can override this property.
'''
immutable_only = False
def __init__(
self,
g,
batch_size,
seed_edges=None,
shuffle=False,
num_workers=1,
prefetch=False,
negative_mode="",
neg_sample_size=0,
exclude_positive=False):
self._g = g
if self.immutable_only and not g._graph.is_readonly():
raise NotImplementedError("This loader only support read-only graphs.")
self._batch_size = int(batch_size)
if seed_edges is None:
self._seed_edges = F.arange(0, g.number_of_edges())
else:
self._seed_edges = seed_edges
if shuffle:
self._seed_edges = F.rand_shuffle(self._seed_edges)
self._seed_edges = utils.toindex(self._seed_edges)
if prefetch:
self._prefetching_wrapper_class = ThreadPrefetchingWrapper
self._num_prefetch = num_workers * 2 if prefetch else 0
self._num_workers = int(num_workers)
self._negative_mode = negative_mode
self._neg_sample_size = neg_sample_size
self._exclude_positive = exclude_positive
def fetch(self, current_index):
'''Fetch the next batches of edge subgraphs.
It returns a list of subgraphs if the sampler only samples positive edges,
and a list of (positive, negative) subgraph pairs if it samples both positive
and negative edges.
Parameters
----------
current_index : int
How many batches the sampler has generated so far.
Returns
-------
list[DGLSubGraph] or list[(DGLSubGraph, DGLSubGraph)]
The next batches of edge subgraphs to be processed.
'''
subgs = _CAPI_UniformEdgeSampling(
self.g._graph,
self.seed_edges.todgltensor(),
current_index, # start batch id
self.batch_size, # batch size
self._num_workers, # num batches
self._negative_mode,
self._neg_sample_size,
self._exclude_positive)
if len(subgs) == 0:
return []
if self._negative_mode == "":
# If no negative subgraphs.
return [subgraph.DGLSubGraph(self.g, subg) for subg in subgs]
else:
rets = []
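# The C API returns the positive subgraphs of all workers first, followed by
# the corresponding negative subgraphs, so subgs holds 2 * num_workers entries.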
assert self._num_workers * 2 == len(subgs)
for i in range(self._num_workers):
pos_subg = subgraph.DGLSubGraph(self.g, subgs[i])
neg_subg = subgraph.DGLSubGraph(self.g, subgs[i + self._num_workers])
rets.append((pos_subg, neg_subg))
return rets
def __iter__(self):
it = SamplerIter(self)
if self._num_prefetch:
return self._prefetching_wrapper_class(it, self._num_prefetch)
else:
return it
@property
def g(self):
return self._g
@property
def seed_edges(self):
return self._seed_edges
@property
def batch_size(self):
return self._batch_size
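# Illustrative usage sketch (not part of the module's code): how the EdgeSampler
# above is typically driven for link prediction. The random graph below is just a
# stand-in for any read-only DGLGraph; scipy is assumed to be available.
def _example_edge_sampler_usage():
    import scipy.sparse as sp
    import dgl
    # Edge sampling requires an immutable (read-only) graph.
    g = dgl.DGLGraph(sp.random(100, 100, density=0.1, format='coo'), readonly=True)
    sampler = EdgeSampler(g, batch_size=50,
                          negative_mode='head',
                          neg_sample_size=10,
                          exclude_positive=True)
    for pos_subg, neg_subg in sampler:
        # Each positive edge is paired with neg_sample_size negative edges.
        assert neg_subg.number_of_edges() == 10 * pos_subg.number_of_edges()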
def create_full_nodeflow(g, num_layers, add_self_loop=False):
"""Convert a full graph to NodeFlow to run a L-layer GNN model.
@@ -324,12 +324,12 @@ Subgraph COO::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
IdArray new_dst = aten::IndexSelect(adj_.col, eids);
induced_nodes = aten::Relabel_({new_src, new_dst});
const auto new_nnodes = induced_nodes->shape[0];
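// Pass the parent graph's multigraph flag to the sub-COO so an edge subgraph
// of a multigraph keeps duplicate-edge semantics.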
subcoo = COOPtr(new COO(new_nnodes, new_src, new_dst));
subcoo = COOPtr(new COO(new_nnodes, new_src, new_dst, this->IsMultigraph()));
} else {
IdArray new_src = aten::IndexSelect(adj_.row, eids);
IdArray new_dst = aten::IndexSelect(adj_.col, eids);
induced_nodes = aten::Range(0, NumVertices(), NumBits(), Context());
subcoo = COOPtr(new COO(NumVertices(), new_src, new_dst));
subcoo = COOPtr(new COO(NumVertices(), new_src, new_dst, this->IsMultigraph()));
}
Subgraph subg;
subg.graph = subcoo;
@@ -115,6 +115,24 @@ void RandomSample(size_t set_size, size_t num, std::vector<size_t>* out) {
out->insert(out->end(), sampled_idxs.begin(), sampled_idxs.end());
}
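/*
 * Sample `num` unique values uniformly from [0, set_size) while excluding the
 * values listed in `exclude`. The excluded values are pre-inserted into the
 * hash map with a 0 marker; freshly drawn values get a 1 marker, and only the
 * fresh draws are copied to `out` once num + exclude.size() distinct keys exist.
 */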
void RandomSample(size_t set_size, size_t num, const std::vector<size_t> &exclude,
std::vector<size_t>* out) {
std::unordered_map<size_t, int> sampled_idxs;
for (auto v : exclude) {
sampled_idxs.insert(std::pair<size_t, int>(v, 0));
}
while (sampled_idxs.size() < num + exclude.size()) {
size_t rand = RandomEngine::ThreadLocal()->RandInt(set_size);
sampled_idxs.insert(std::pair<size_t, int>(rand, 1));
}
out->clear();
for (auto it = sampled_idxs.begin(); it != sampled_idxs.end(); it++) {
if (it->second) {
out->push_back(it->first);
}
}
}
/*
* For a sparse array whose non-zeros are represented by nz_idxs,
* negate the sparse array and outputs the non-zeros in the negated array.
@@ -865,4 +883,176 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling")
*rv = List<NodeFlow>(nflows);
});
namespace {
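// Materialize the graph's COO representation up front so it isn't constructed
// lazily inside the parallel sampling workers.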
void BuildCoo(const ImmutableGraph &g) {
auto coo = g.GetCOO();
assert(coo);
}
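// Map a vertex ID of the parent graph to a compact local ID in the negative
// subgraph, assigning a fresh local ID on first encounter.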
dgl_id_t global2local_map(dgl_id_t global_id,
std::unordered_map<dgl_id_t, dgl_id_t> *map) {
auto it = map->find(global_id);
if (it == map->end()) {
dgl_id_t local_id = map->size();
map->insert(std::pair<dgl_id_t, dgl_id_t>(global_id, local_id));
return local_id;
} else {
return it->second;
}
}
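/*
 * Construct a negative subgraph from a positive edge subgraph. For each positive
 * edge, one endpoint is kept (the tail if neg_mode == "head", the head otherwise)
 * and the other endpoint is replaced by neg_sample_size vertices sampled uniformly
 * from the parent graph; with exclude_positive, the kept endpoint's existing
 * neighbors are excluded from the candidates. Sampled vertices are relabeled to
 * compact local IDs, and each negative edge's induced edge ID points to the
 * positive edge it corrupts.
 */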
Subgraph NegEdgeSubgraph(int64_t num_tot_nodes, const Subgraph &pos_subg,
const std::string &neg_mode,
int neg_sample_size, bool is_multigraph,
bool exclude_positive) {
std::vector<IdArray> adj = pos_subg.graph->GetAdj(false, "coo");
IdArray coo = adj[0];
int64_t num_pos_edges = coo->shape[0] / 2;
int64_t num_neg_edges = num_pos_edges * neg_sample_size;
IdArray neg_dst = IdArray::Empty({num_neg_edges}, coo->dtype, coo->ctx);
IdArray neg_src = IdArray::Empty({num_neg_edges}, coo->dtype, coo->ctx);
IdArray neg_eid = IdArray::Empty({num_neg_edges}, coo->dtype, coo->ctx);
IdArray induced_neg_eid = IdArray::Empty({num_neg_edges}, coo->dtype, coo->ctx);
// These are vids in the positive subgraph.
const dgl_id_t *dst_data = static_cast<const dgl_id_t *>(coo->data);
const dgl_id_t *src_data = static_cast<const dgl_id_t *>(coo->data) + num_pos_edges;
const dgl_id_t *induced_vid_data = static_cast<const dgl_id_t *>(pos_subg.induced_vertices->data);
const dgl_id_t *induced_eid_data = static_cast<const dgl_id_t *>(pos_subg.induced_edges->data);
size_t num_pos_nodes = pos_subg.graph->NumVertices();
std::vector<size_t> pos_nodes(induced_vid_data, induced_vid_data + num_pos_nodes);
dgl_id_t *neg_dst_data = static_cast<dgl_id_t *>(neg_dst->data);
dgl_id_t *neg_src_data = static_cast<dgl_id_t *>(neg_src->data);
dgl_id_t *neg_eid_data = static_cast<dgl_id_t *>(neg_eid->data);
dgl_id_t *induced_neg_eid_data = static_cast<dgl_id_t *>(induced_neg_eid->data);
bool neg_head = (neg_mode == "head");
dgl_id_t curr_eid = 0;
std::vector<size_t> neg_vids;
neg_vids.reserve(neg_sample_size);
std::unordered_map<dgl_id_t, dgl_id_t> neg_map;
for (int64_t i = 0; i < num_pos_edges; i++) {
size_t neg_idx = i * neg_sample_size;
neg_vids.clear();
std::vector<size_t> neighbors;
DGLIdIters neigh_it;
const dgl_id_t *unchanged;
dgl_id_t *neg_unchanged;
dgl_id_t *neg_changed;
if (neg_head) {
unchanged = dst_data;
neg_unchanged = neg_dst_data;
neg_changed = neg_src_data;
neigh_it = pos_subg.graph->PredVec(unchanged[i]);
} else {
unchanged = src_data;
neg_unchanged = neg_src_data;
neg_changed = neg_dst_data;
neigh_it = pos_subg.graph->SuccVec(unchanged[i]);
}
if (exclude_positive) {
std::vector<size_t> exclude;
for (auto it = neigh_it.begin(); it != neigh_it.end(); it++) {
dgl_id_t local_vid = *it;
exclude.push_back(induced_vid_data[local_vid]);
}
RandomSample(num_tot_nodes, neg_sample_size, exclude, &neg_vids);
} else {
RandomSample(num_tot_nodes, neg_sample_size, &neg_vids);
}
dgl_id_t global_unchanged = induced_vid_data[unchanged[i]];
dgl_id_t local_unchanged = global2local_map(global_unchanged, &neg_map);
for (int64_t j = 0; j < neg_sample_size; j++) {
neg_unchanged[neg_idx + j] = local_unchanged;
neg_eid_data[neg_idx + j] = curr_eid++;
dgl_id_t local_changed = global2local_map(neg_vids[j], &neg_map);
neg_changed[neg_idx + j] = local_changed;
// The induced negative eid refers to the positive edge it corrupts.
induced_neg_eid_data[neg_idx + j] = induced_eid_data[i];
}
}
// Now we know the number of vertices in the negative graph.
int64_t num_neg_nodes = neg_map.size();
IdArray induced_neg_vid = IdArray::Empty({num_neg_nodes}, coo->dtype, coo->ctx);
dgl_id_t *induced_neg_vid_data = static_cast<dgl_id_t *>(induced_neg_vid->data);
for (auto it = neg_map.begin(); it != neg_map.end(); it++) {
induced_neg_vid_data[it->second] = it->first;
}
Subgraph neg_subg;
// We sample negative vertices without replacement.
// There shouldn't be duplicated edges.
COOPtr neg_coo(new COO(num_neg_nodes, neg_src, neg_dst, is_multigraph));
neg_subg.graph = GraphPtr(new ImmutableGraph(neg_coo));
neg_subg.induced_vertices = induced_neg_vid;
neg_subg.induced_edges = induced_neg_eid;
return neg_subg;
}
inline SubgraphRef ConvertRef(const Subgraph &subg) {
return SubgraphRef(std::shared_ptr<Subgraph>(new Subgraph(subg)));
}
} // namespace
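// C API entry point for uniform edge sampling. Each worker processes one batch of
// seed edges, builds the positive edge subgraph and, if neg_mode is non-empty, the
// matching negative subgraph. The returned list contains all positive subgraphs
// followed by all negative subgraphs.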
DGL_REGISTER_GLOBAL("sampling._CAPI_UniformEdgeSampling")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
// arguments
GraphRef g = args[0];
IdArray seed_edges = args[1];
const int64_t batch_start_id = args[2];
const int64_t batch_size = args[3];
const int64_t max_num_workers = args[4];
const std::string neg_mode = args[5];
const int neg_sample_size = args[6];
const bool exclude_positive = args[7];
// process args
auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(gptr) << "sampling isn't implemented for mutable graphs";
CHECK(IsValidIdArray(seed_edges));
BuildCoo(*gptr);
const int64_t num_seeds = seed_edges->shape[0];
const int64_t num_workers = std::min(max_num_workers,
(num_seeds + batch_size - 1) / batch_size - batch_start_id);
// generate subgraphs.
std::vector<SubgraphRef> positive_subgs(num_workers);
std::vector<SubgraphRef> negative_subgs(num_workers);
#pragma omp parallel for
for (int i = 0; i < num_workers; i++) {
const int64_t start = (batch_start_id + i) * batch_size;
const int64_t end = std::min(start + batch_size, num_seeds);
const int64_t num_edges = end - start;
IdArray worker_seeds = seed_edges.CreateView({num_edges}, DLDataType{kDLInt, 64, 1},
sizeof(dgl_id_t) * start);
EdgeArray arr = gptr->FindEdges(worker_seeds);
const dgl_id_t *src_ids = static_cast<const dgl_id_t *>(arr.src->data);
const dgl_id_t *dst_ids = static_cast<const dgl_id_t *>(arr.dst->data);
std::vector<dgl_id_t> src_vec(src_ids, src_ids + num_edges);
std::vector<dgl_id_t> dst_vec(dst_ids, dst_ids + num_edges);
// TODO(zhengda) what if there are duplicates in the src and dst vectors.
Subgraph subg = gptr->EdgeSubgraph(worker_seeds, false);
positive_subgs[i] = ConvertRef(subg);
if (neg_mode.size() > 0) {
Subgraph neg_subg = NegEdgeSubgraph(gptr->NumVertices(), subg,
neg_mode, neg_sample_size,
gptr->IsMultigraph(), exclude_positive);
negative_subgs[i] = ConvertRef(neg_subg);
}
}
if (neg_mode.size() > 0) {
positive_subgs.insert(positive_subgs.end(), negative_subgs.begin(), negative_subgs.end());
}
*rv = List<SubgraphRef>(positive_subgs);
});
} // namespace dgl
@@ -3,6 +3,7 @@ import numpy as np
import scipy as sp
import dgl
from dgl import utils
from numpy.testing import assert_array_equal
np.random.seed(42)
@@ -219,6 +220,42 @@ def test_setseed():
g, 5, 3, num_hops=2, neighbor_type='in', num_workers=4)):
pass
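# The test below exercises the head-corruption mode: every negative edge must keep
# the tail node and parent edge ID of some positive edge while its head node differs
# from that positive edge's head.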
def test_negative_sampler():
g = generate_rand_graph(100)
EdgeSampler = getattr(dgl.contrib.sampling, 'EdgeSampler')
for pos_edges, neg_edges in EdgeSampler(g, 50,
negative_mode="head",
neg_sample_size=10,
exclude_positive=True):
assert 10 * pos_edges.number_of_edges() == neg_edges.number_of_edges()
pos_nid = pos_edges.parent_nid
pos_eid = pos_edges.parent_eid
pos_lsrc, pos_ldst, pos_leid = pos_edges.all_edges(form='all', order='eid')
pos_src = pos_nid[pos_lsrc]
pos_dst = pos_nid[pos_ldst]
pos_eid = pos_eid[pos_leid]
assert_array_equal(F.asnumpy(pos_eid), F.asnumpy(g.edge_ids(pos_src, pos_dst)))
pos_map = {}
for i in range(len(pos_eid)):
pos_d = int(F.asnumpy(pos_dst[i]))
pos_e = int(F.asnumpy(pos_eid[i]))
pos_map[(pos_d, pos_e)] = int(F.asnumpy(pos_src[i]))
neg_lsrc, neg_ldst, neg_leid = neg_edges.all_edges(form='all', order='eid')
neg_nid = neg_edges.parent_nid
neg_eid = neg_edges.parent_eid
neg_src = neg_nid[neg_lsrc]
neg_dst = neg_nid[neg_ldst]
neg_eid = neg_eid[neg_leid]
for i in range(len(neg_eid)):
neg_d = int(F.asnumpy(neg_dst[i]))
neg_e = int(F.asnumpy(neg_eid[i]))
assert (neg_d, neg_e) in pos_map
assert int(F.asnumpy(neg_src[i])) != pos_map[(neg_d, neg_e)]
if __name__ == '__main__':
test_create_full()
test_1neighbor_sampler_all()
@@ -228,3 +265,4 @@ if __name__ == '__main__':
test_layer_sampler()
test_nonuniform_neighbor_sampler()
test_setseed()
test_negative_sampler()