Commit 632a9af8 authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by Da Zheng
Browse files

[Feature] Non-Uniform Edge Sampler (#1087)

* Add weight based edge sampler

* Can run, edge weight work.
TODO: test node weight

* Fix node weight sample

* Fix y

* Update doc

* Fix syntax

* Fix

* Fix GPU test for sampler

* Fix test

* Fix

* Refactor EdgeSampler to act as class object not function that it
can record its own private states.

* clean

* Fix

* Fix

* Fix run bug on kg app

* update

* update test

* test

* Simplify python API and fix some C code

* Fix

* Fix

* Fix syntax

* Fix

* Update API description
parent dd65ee21
......@@ -7,6 +7,7 @@ from numbers import Integral
import traceback
from ..._ffi.function import _init_api
from ..._ffi.object import register_object, ObjectBase
from ..._ffi.ndarray import empty
from ... import utils
from ...nodeflow import NodeFlow
......@@ -509,6 +510,16 @@ class EdgeSampler(object):
The sampler returns EdgeSubgraph, where a user can access the unique head nodes
and tail nodes directly.
This sampler allows non-uniform sampling of positive edges and negative edges.
For non-uniform sampling of positive edges, users need to provide an array of m
elements (m is the number of edges), i.e. edge_weight, each of which represents
the sampling probability of an edge. For non-uniform sampling of negative edges,
users need to provide an array of n elements, i.e. node_weight, and the sampler
samples nodes according to these probabilities to corrupt a positive edge. If
neither edge_weight nor node_weight is provided, a uniform sampler is used.
If only edge_weight is provided, the sampler falls back to uniform sampling
when corrupting positive edges.
When the flag `return_false_neg` is turned on, the sampler will also check
if the generated negative edges are true negative edges and will return
a vector that indicates false negative edges. The vector is stored in
......@@ -519,6 +530,11 @@ class EdgeSampler(object):
edge only if the triple (source node, destination node and relation)
matches one of the edges in the graph.
For uniform sampling, the sampler generates only num_of_edges/batch_size
samples.
For non-uniform (weighted) sampling, the sampler generates samples infinitely.
Parameters
----------
g : DGLGraph
......@@ -527,6 +543,11 @@ class EdgeSampler(object):
The batch size (i.e, the number of edges from the graph)
seed_edges : tensor, optional
A list of edges where we sample from.
edge_weight : tensor, optional
The weight of each edge, which decides the chance of the edge being sampled.
node_weight : tensor, optional
The weight of each node, which decides the chance of the node being sampled.
Used in negative sampling. If not provided, uniform node sampling is used.
shuffle : bool, optional
whether randomly shuffle the list of edges where we sample from.
num_workers : int, optional
......@@ -564,6 +585,8 @@ class EdgeSampler(object):
g,
batch_size,
seed_edges=None,
edge_weight=None,
node_weight=None,
shuffle=False,
num_workers=1,
prefetch=False,
......@@ -596,6 +619,16 @@ class EdgeSampler(object):
self._seed_edges = seed_edges
if shuffle:
self._seed_edges = F.rand_shuffle(self._seed_edges)
if edge_weight is None:
self._is_uniform = True
else:
self._is_uniform = False
self._edge_weight = F.zerocopy_to_dgl_ndarray(edge_weight[self._seed_edges])
if node_weight is None:
self._node_weight = empty((0,), 'float32')
else:
self._node_weight = F.zerocopy_to_dgl_ndarray(node_weight)
self._seed_edges = utils.toindex(self._seed_edges)
if prefetch:
......@@ -606,6 +639,30 @@ class EdgeSampler(object):
self._negative_mode = negative_mode
self._neg_sample_size = neg_sample_size
self._exclude_positive = exclude_positive
if self._is_uniform:
self._sampler = _CAPI_CreateUniformEdgeSampler(
self.g._graph,
self.seed_edges.todgltensor(),
self.batch_size, # batch size
self._num_workers, # num batches
self._negative_mode,
self._neg_sample_size,
self._exclude_positive,
self._return_false_neg,
self._relations)
else:
self._sampler = _CAPI_CreateWeightedEdgeSampler(
self.g._graph,
self._seed_edges.todgltensor(),
self._edge_weight,
self._node_weight,
self._batch_size, # batch size
self._num_workers, # num batches
self._negative_mode,
self._neg_sample_size,
self._exclude_positive,
self._return_false_neg,
self._relations)
def fetch(self, current_index):
'''
......@@ -616,24 +673,19 @@ class EdgeSampler(object):
Parameters
----------
current_index : int
How many batches the sampler has generated so far.
Deprecated; this argument is no longer used.
Returns
-------
list[GraphIndex] or list[(GraphIndex, GraphIndex)]
Next "bunch" of edges to be processed.
'''
subgs = _CAPI_UniformEdgeSampling(
self.g._graph,
self.seed_edges.todgltensor(),
current_index, # start batch id
self.batch_size, # batch size
self._num_workers, # num batches
self._negative_mode,
self._neg_sample_size,
self._exclude_positive,
self._return_false_neg,
self._relations)
if self._is_uniform:
subgs = _CAPI_FetchUniformEdgeSample(
self._sampler)
else:
subgs = _CAPI_FetchWeightedEdgeSample(
self._sampler)
if len(subgs) == 0:
return []
......@@ -673,7 +725,6 @@ class EdgeSampler(object):
def batch_size(self):
return self._batch_size
def create_full_nodeflow(g, num_layers, add_self_loop=False):
"""Convert a full graph to NodeFlow to run a L-layer GNN model.
......
......@@ -30,16 +30,16 @@ class ArrayHeap {
explicit ArrayHeap(const std::vector<ValueType>& prob) {
vec_size_ = prob.size();
bit_len_ = ceil(log2(vec_size_));
limit_ = 1 << bit_len_;
limit_ = 1UL << bit_len_;
// allocate twice the size
heap_.resize(limit_ << 1, 0);
// allocate the leaves
for (int i = limit_; i < vec_size_+limit_; ++i) {
for (size_t i = limit_; i < vec_size_+limit_; ++i) {
heap_[i] = prob[i-limit_];
}
// iterate up the tree (this is O(m))
for (int i = bit_len_-1; i >= 0; --i) {
for (int j = (1 << i); j < (1 << (i + 1)); ++j) {
for (size_t j = (1UL << i); j < (1UL << (i + 1)); ++j) {
heap_[j] = heap_[j << 1] + heap_[(j << 1) + 1];
}
}
......@@ -74,7 +74,7 @@ class ArrayHeap {
*/
size_t Sample() {
ValueType xi = heap_[1] * RandomEngine::ThreadLocal()->Uniform<float>();
int i = 1;
size_t i = 1;
while (i < limit_) {
i = i << 1;
if (xi >= heap_[i]) {
......@@ -97,12 +97,68 @@ class ArrayHeap {
}
private:
int vec_size_; // sample size
size_t vec_size_; // sample size
int bit_len_; // bit size
int limit_;
size_t limit_;
std::vector<ValueType> heap_;
};
///////////////////////// Samplers //////////////////////////
// Base class shared by the uniform and weighted edge samplers.
// It captures the sampling configuration at construction time and declares
// the Fetch()/randomSample() interface that concrete samplers implement.
class EdgeSamplerObject: public Object {
 public:
  EdgeSamplerObject(const GraphPtr gptr,
                    IdArray seed_edges,
                    const int64_t batch_size,
                    const int64_t num_workers,
                    const std::string neg_mode,
                    const int64_t neg_sample_size,
                    const bool exclude_positive,
                    const bool check_false_neg,
                    IdArray relations) {
    gptr_ = gptr;
    seed_edges_ = seed_edges;
    relations_ = relations;
    batch_size_ = batch_size;
    num_workers_ = num_workers;
    neg_mode_ = neg_mode;
    neg_sample_size_ = neg_sample_size;
    exclude_positive_ = exclude_positive;
    check_false_neg_ = check_false_neg;
  }
  // NOTE(review): destructor is non-virtual although the class is abstract
  // and used polymorphically; a virtual destructor would be safer — confirm
  // that deletion always happens through the concrete shared_ptr type.
  ~EdgeSamplerObject() {}

  // Produce the next batch(es) of positive (and optional negative)
  // subgraphs into *rv.
  virtual void Fetch(DGLRetValue* rv) = 0;

 protected:
  // Sample `num` distinct values from [0, set_size) into *out.
  virtual void randomSample(size_t set_size, size_t num, std::vector<size_t>* out) = 0;
  // Same as above, but values listed in `exclude` are never returned.
  virtual void randomSample(size_t set_size, size_t num, const std::vector<size_t> &exclude,
                            std::vector<size_t>* out) = 0;

  // Build a negative subgraph for `pos_subg` by corrupting head or tail
  // nodes (chosen by `neg_mode`).
  NegSubgraph genNegEdgeSubgraph(const Subgraph &pos_subg,
                                 const std::string &neg_mode,
                                 int64_t neg_sample_size,
                                 bool exclude_positive,
                                 bool check_false_neg);
  // PBG-style variant of the above (modes "PBG-head"/"PBG-tail" upstream);
  // negative nodes are shared among a chunk of positive edges.
  NegSubgraph genPBGNegEdgeSubgraph(const Subgraph &pos_subg,
                                    const std::string &neg_mode,
                                    int64_t neg_sample_size,
                                    bool exclude_positive,
                                    bool check_false_neg);

  GraphPtr gptr_;           // parent graph being sampled
  IdArray seed_edges_;      // candidate edge ids to sample from
  IdArray relations_;       // per-edge relation ids; empty array => none
  int64_t batch_size_;
  int64_t num_workers_;     // number of batches produced per Fetch()
  std::string neg_mode_;    // negative sampling mode; "" disables negatives
  int64_t neg_sample_size_;
  bool exclude_positive_;
  bool check_false_neg_;
};
/*
* Uniformly sample integers from [0, set_size) without replacement.
*/
......@@ -988,14 +1044,15 @@ std::vector<dgl_id_t> Global2Local(const std::vector<size_t> &ids,
return local_ids;
}
NegSubgraph NegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph &pos_subg,
const std::string &neg_mode,
int neg_sample_size, bool exclude_positive,
bool check_false_neg) {
int64_t num_tot_nodes = gptr->NumVertices();
NegSubgraph EdgeSamplerObject::genNegEdgeSubgraph(const Subgraph &pos_subg,
const std::string &neg_mode,
int64_t neg_sample_size,
bool exclude_positive,
bool check_false_neg) {
int64_t num_tot_nodes = gptr_->NumVertices();
if (neg_sample_size > num_tot_nodes)
neg_sample_size = num_tot_nodes;
bool is_multigraph = gptr->IsMultigraph();
bool is_multigraph = gptr_->IsMultigraph();
std::vector<IdArray> adj = pos_subg.graph->GetAdj(false, "coo");
IdArray coo = adj[0];
int64_t num_pos_edges = coo->shape[0] / 2;
......@@ -1008,8 +1065,10 @@ NegSubgraph NegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph &po
// These are vids in the positive subgraph.
const dgl_id_t *dst_data = static_cast<const dgl_id_t *>(coo->data);
const dgl_id_t *src_data = static_cast<const dgl_id_t *>(coo->data) + num_pos_edges;
const dgl_id_t *induced_vid_data = static_cast<const dgl_id_t *>(pos_subg.induced_vertices->data);
const dgl_id_t *induced_eid_data = static_cast<const dgl_id_t *>(pos_subg.induced_edges->data);
const dgl_id_t *induced_vid_data =
static_cast<const dgl_id_t *>(pos_subg.induced_vertices->data);
const dgl_id_t *induced_eid_data =
static_cast<const dgl_id_t *>(pos_subg.induced_edges->data);
size_t num_pos_nodes = pos_subg.graph->NumVertices();
std::vector<size_t> pos_nodes(induced_vid_data, induced_vid_data + num_pos_nodes);
......@@ -1076,9 +1135,9 @@ NegSubgraph NegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph &po
std::vector<size_t> neighbors;
DGLIdIters neigh_it;
if (IsNegativeHeadMode(neg_mode)) {
neigh_it = gptr->PredVec(induced_vid_data[unchanged[i]]);
neigh_it = gptr_->PredVec(induced_vid_data[unchanged[i]]);
} else {
neigh_it = gptr->SuccVec(induced_vid_data[unchanged[i]]);
neigh_it = gptr_->SuccVec(induced_vid_data[unchanged[i]]);
}
// If the number of negative nodes is smaller than the number of total nodes
......@@ -1090,14 +1149,15 @@ NegSubgraph NegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph &po
exclude.push_back(global_vid);
}
prev_neg_offset = neg_vids.size();
RandomSample(num_tot_nodes, neg_sample_size, exclude, &neg_vids);
randomSample(num_tot_nodes, neg_sample_size, exclude, &neg_vids);
assert(prev_neg_offset + neg_sample_size == neg_vids.size());
} else if (neg_sample_size < num_tot_nodes) {
prev_neg_offset = neg_vids.size();
RandomSample(num_tot_nodes, neg_sample_size, &neg_vids);
randomSample(num_tot_nodes, neg_sample_size, &neg_vids);
assert(prev_neg_offset + neg_sample_size == neg_vids.size());
} else if (exclude_positive) {
LOG(FATAL) << "We can't exclude positive edges when sampling negative edges with all nodes.";
LOG(FATAL) << "We can't exclude positive edges"
"when sampling negative edges with all nodes.";
} else {
// We don't need to do anything here.
// In this case, every edge has the same negative edges. That is,
......@@ -1149,21 +1209,22 @@ NegSubgraph NegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph &po
}
// TODO(zhengda) we should provide an array of 1s if exclude_positive
if (check_false_neg) {
if (relations->shape[0] == 0) {
neg_subg.exist = CheckExistence(gptr, neg_src, neg_dst, induced_neg_vid);
if (relations_->shape[0] == 0) {
neg_subg.exist = CheckExistence(gptr_, neg_src, neg_dst, induced_neg_vid);
} else {
neg_subg.exist = CheckExistence(gptr, relations, neg_src, neg_dst,
neg_subg.exist = CheckExistence(gptr_, relations_, neg_src, neg_dst,
induced_neg_vid, induced_neg_eid);
}
}
return neg_subg;
}
NegSubgraph PBGNegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph &pos_subg,
const std::string &neg_mode,
int neg_sample_size, bool is_multigraph,
bool exclude_positive, bool check_false_neg) {
int64_t num_tot_nodes = gptr->NumVertices();
NegSubgraph EdgeSamplerObject::genPBGNegEdgeSubgraph(const Subgraph &pos_subg,
const std::string &neg_mode,
int64_t neg_sample_size,
bool exclude_positive,
bool check_false_neg) {
int64_t num_tot_nodes = gptr_->NumVertices();
std::vector<IdArray> adj = pos_subg.graph->GetAdj(false, "coo");
IdArray coo = adj[0];
int64_t num_pos_edges = coo->shape[0] / 2;
......@@ -1195,10 +1256,12 @@ NegSubgraph PBGNegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph
// These are vids in the positive subgraph.
const dgl_id_t *dst_data = static_cast<const dgl_id_t *>(coo->data);
const dgl_id_t *src_data = static_cast<const dgl_id_t *>(coo->data) + num_pos_edges;
const dgl_id_t *induced_vid_data = static_cast<const dgl_id_t *>(pos_subg.induced_vertices->data);
const dgl_id_t *induced_eid_data = static_cast<const dgl_id_t *>(pos_subg.induced_edges->data);
size_t num_pos_nodes = pos_subg.graph->NumVertices();
std::vector<size_t> pos_nodes(induced_vid_data, induced_vid_data + num_pos_nodes);
const dgl_id_t *induced_vid_data =
static_cast<const dgl_id_t *>(pos_subg.induced_vertices->data);
const dgl_id_t *induced_eid_data =
static_cast<const dgl_id_t *>(pos_subg.induced_edges->data);
int64_t num_pos_nodes = pos_subg.graph->NumVertices();
std::vector<dgl_id_t> pos_nodes(induced_vid_data, induced_vid_data + num_pos_nodes);
dgl_id_t *neg_dst_data = static_cast<dgl_id_t *>(neg_dst->data);
dgl_id_t *neg_src_data = static_cast<dgl_id_t *>(neg_src->data);
......@@ -1208,14 +1271,11 @@ NegSubgraph PBGNegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph
const dgl_id_t *unchanged;
dgl_id_t *neg_unchanged;
dgl_id_t *neg_changed;
// corrupt head nodes.
if (IsNegativeHeadMode(neg_mode)) {
unchanged = dst_data;
neg_unchanged = neg_dst_data;
neg_changed = neg_src_data;
} else {
// corrupt tail nodes.
unchanged = src_data;
neg_unchanged = neg_src_data;
neg_changed = neg_dst_data;
......@@ -1223,13 +1283,14 @@ NegSubgraph PBGNegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph
// We first sample all negative edges.
std::vector<size_t> neg_vids;
RandomSample(num_tot_nodes,
randomSample(num_tot_nodes,
num_chunks * neg_sample_size,
&neg_vids);
dgl_id_t curr_eid = 0;
std::unordered_map<dgl_id_t, dgl_id_t> neg_map;
dgl_id_t local_vid = 0;
// Collect nodes in the positive side.
std::vector<dgl_id_t> local_pos_vids;
local_pos_vids.reserve(num_pos_edges);
......@@ -1256,7 +1317,6 @@ NegSubgraph PBGNegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph
for (int64_t in_chunk = 0; in_chunk != chunk_size1; ++in_chunk) {
// For each positive node in a chunk.
dgl_id_t global_unchanged = induced_vid_data[unchanged[pos_edge_idx + in_chunk]];
dgl_id_t local_unchanged = global2local_map(global_unchanged, &neg_map);
for (int64_t j = 0; j < neg_sample_size; ++j) {
......@@ -1284,7 +1344,7 @@ NegSubgraph PBGNegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph
NegSubgraph neg_subg;
// We sample negative vertices without replacement.
// There shouldn't be duplicated edges.
COOPtr neg_coo(new COO(num_neg_nodes, neg_src, neg_dst, is_multigraph));
COOPtr neg_coo(new COO(num_neg_nodes, neg_src, neg_dst, gptr_->IsMultigraph()));
neg_subg.graph = GraphPtr(new ImmutableGraph(neg_coo));
neg_subg.induced_vertices = induced_neg_vid;
neg_subg.induced_edges = induced_neg_eid;
......@@ -1296,10 +1356,10 @@ NegSubgraph PBGNegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph
neg_subg.tail_nid = aten::VecToIdArray(Global2Local(neg_vids, neg_map));
}
if (check_false_neg) {
if (relations->shape[0] == 0) {
neg_subg.exist = CheckExistence(gptr, neg_src, neg_dst, induced_neg_vid);
if (relations_->shape[0] == 0) {
neg_subg.exist = CheckExistence(gptr_, neg_src, neg_dst, induced_neg_vid);
} else {
neg_subg.exist = CheckExistence(gptr, relations, neg_src, neg_dst,
neg_subg.exist = CheckExistence(gptr_, relations_, neg_src, neg_dst,
induced_neg_vid, induced_neg_eid);
}
}
......@@ -1316,87 +1376,398 @@ inline SubgraphRef ConvertRef(const NegSubgraph &subg) {
} // namespace
DGL_REGISTER_GLOBAL("sampling._CAPI_UniformEdgeSampling")
DGL_REGISTER_GLOBAL("sampling._CAPI_GetNegEdgeExistence")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
// arguments
GraphRef g = args[0];
IdArray seed_edges = args[1];
const int64_t batch_start_id = args[2];
const int64_t batch_size = args[3];
const int64_t max_num_workers = args[4];
const std::string neg_mode = args[5];
const int neg_sample_size = args[6];
const bool exclude_positive = args[7];
const bool check_false_neg = args[8];
IdArray relations = args[9];
// process args
auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(gptr) << "sampling isn't implemented in mutable graph";
CHECK(aten::IsValidIdArray(seed_edges));
BuildCoo(*gptr);
SubgraphRef g = args[0];
auto gptr = std::dynamic_pointer_cast<NegSubgraph>(g.sptr());
*rv = gptr->exist;
});
const int64_t num_seeds = seed_edges->shape[0];
const int64_t num_workers = std::min(max_num_workers,
(num_seeds + batch_size - 1) / batch_size - batch_start_id);
DGL_REGISTER_GLOBAL("sampling._CAPI_GetEdgeSubgraphHead")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
SubgraphRef g = args[0];
auto gptr = std::dynamic_pointer_cast<NegSubgraph>(g.sptr());
*rv = gptr->head_nid;
});
DGL_REGISTER_GLOBAL("sampling._CAPI_GetEdgeSubgraphTail")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
SubgraphRef g = args[0];
auto gptr = std::dynamic_pointer_cast<NegSubgraph>(g.sptr());
*rv = gptr->tail_nid;
});
class UniformEdgeSamplerObject: public EdgeSamplerObject {
public:
explicit UniformEdgeSamplerObject(const GraphPtr gptr,
IdArray seed_edges,
const int64_t batch_size,
const int64_t num_workers,
const std::string neg_mode,
const int64_t neg_sample_size,
const bool exclude_positive,
const bool check_false_neg,
IdArray relations)
: EdgeSamplerObject(gptr,
seed_edges,
batch_size,
num_workers,
neg_mode,
neg_sample_size,
exclude_positive,
check_false_neg,
relations) {
batch_curr_id_ = 0;
num_seeds_ = seed_edges->shape[0];
max_batch_id_ = (num_seeds_ + batch_size - 1) / batch_size;
// TODO(song): Tricky thing here to make sure gptr_ has coo cache
gptr_->FindEdge(0);
}
~UniformEdgeSamplerObject() {}
void Fetch(DGLRetValue* rv) {
const int64_t num_workers = std::min(num_workers_, max_batch_id_ - batch_curr_id_);
// generate subgraphs.
std::vector<SubgraphRef> positive_subgs(num_workers);
std::vector<SubgraphRef> negative_subgs(num_workers);
#pragma omp parallel for
for (int i = 0; i < num_workers; i++) {
const int64_t start = (batch_start_id + i) * batch_size;
const int64_t end = std::min(start + batch_size, num_seeds);
for (int64_t i = 0; i < num_workers; i++) {
const int64_t start = (batch_curr_id_ + i) * batch_size_;
const int64_t end = std::min(start + batch_size_, num_seeds_);
const int64_t num_edges = end - start;
IdArray worker_seeds = seed_edges.CreateView({num_edges}, DLDataType{kDLInt, 64, 1},
IdArray worker_seeds = seed_edges_.CreateView({num_edges}, DLDataType{kDLInt, 64, 1},
sizeof(dgl_id_t) * start);
EdgeArray arr = gptr->FindEdges(worker_seeds);
EdgeArray arr = gptr_->FindEdges(worker_seeds);
const dgl_id_t *src_ids = static_cast<const dgl_id_t *>(arr.src->data);
const dgl_id_t *dst_ids = static_cast<const dgl_id_t *>(arr.dst->data);
std::vector<dgl_id_t> src_vec(src_ids, src_ids + num_edges);
std::vector<dgl_id_t> dst_vec(dst_ids, dst_ids + num_edges);
// TODO(zhengda) what if there are duplicates in the src and dst vectors.
Subgraph subg = gptr->EdgeSubgraph(worker_seeds, false);
Subgraph subg = gptr_->EdgeSubgraph(worker_seeds, false);
positive_subgs[i] = ConvertRef(subg);
// For PBG negative sampling, we accept "PBG-head" for corrupting head
// nodes and "PBG-tail" for corrupting tail nodes.
if (neg_mode.substr(0, 3) == "PBG") {
NegSubgraph neg_subg = PBGNegEdgeSubgraph(gptr, relations, subg,
neg_mode.substr(4), neg_sample_size,
gptr->IsMultigraph(), exclude_positive,
check_false_neg);
if (neg_mode_.substr(0, 3) == "PBG") {
NegSubgraph neg_subg = genPBGNegEdgeSubgraph(subg, neg_mode_.substr(4),
neg_sample_size_,
exclude_positive_,
check_false_neg_);
negative_subgs[i] = ConvertRef(neg_subg);
} else if (neg_mode.size() > 0) {
NegSubgraph neg_subg = NegEdgeSubgraph(gptr, relations, subg, neg_mode, neg_sample_size,
exclude_positive, check_false_neg);
} else if (neg_mode_.size() > 0) {
NegSubgraph neg_subg = genNegEdgeSubgraph(subg, neg_mode_,
neg_sample_size_,
exclude_positive_,
check_false_neg_);
negative_subgs[i] = ConvertRef(neg_subg);
}
}
if (neg_mode.size() > 0) {
if (neg_mode_.size() > 0) {
positive_subgs.insert(positive_subgs.end(), negative_subgs.begin(), negative_subgs.end());
}
batch_curr_id_ += num_workers;
*rv = List<SubgraphRef>(positive_subgs);
});
}
DGL_DECLARE_OBJECT_TYPE_INFO(UniformEdgeSamplerObject, Object);
DGL_REGISTER_GLOBAL("sampling._CAPI_GetNegEdgeExistence")
private:
void randomSample(size_t set_size, size_t num, std::vector<size_t>* out) {
RandomSample(set_size, num, out);
}
void randomSample(size_t set_size, size_t num, const std::vector<size_t> &exclude,
std::vector<size_t>* out) {
RandomSample(set_size, num, exclude, out);
}
int64_t batch_curr_id_;
int64_t max_batch_id_;
int64_t num_seeds_;
};
// FFI reference (handle) type wrapping a UniformEdgeSamplerObject so the
// sampler can be created once and passed back into later fetch calls.
class UniformEdgeSampler: public ObjectRef {
 public:
  UniformEdgeSampler() {}
  explicit UniformEdgeSampler(std::shared_ptr<runtime::Object> obj): ObjectRef(obj) {}
  // Unchecked downcast; callers must guarantee the wrapped object's type.
  UniformEdgeSamplerObject* operator->() const {
    return static_cast<UniformEdgeSamplerObject*>(obj_.get());
  }
  // Checked access: CHECK-fails if the wrapped object has another type.
  std::shared_ptr<UniformEdgeSamplerObject> sptr() const {
    return CHECK_NOTNULL(std::dynamic_pointer_cast<UniformEdgeSamplerObject>(obj_));
  }
  // True when the handle actually holds an object.
  operator bool() const { return this->defined(); }
  using ContainerType = UniformEdgeSamplerObject;
};
DGL_REGISTER_GLOBAL("sampling._CAPI_CreateUniformEdgeSampler")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
SubgraphRef g = args[0];
auto gptr = std::dynamic_pointer_cast<NegSubgraph>(g.sptr());
*rv = gptr->exist;
// arguments
GraphRef g = args[0];
IdArray seed_edges = args[1];
const int64_t batch_size = args[2];
const int64_t max_num_workers = args[3];
const std::string neg_mode = args[4];
const int neg_sample_size = args[5];
const bool exclude_positive = args[6];
const bool check_false_neg = args[7];
IdArray relations = args[8];
// process args
auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(gptr) << "sampling isn't implemented in mutable graph";
CHECK(aten::IsValidIdArray(seed_edges));
BuildCoo(*gptr);
auto o = std::make_shared<UniformEdgeSamplerObject>(gptr,
seed_edges,
batch_size,
max_num_workers,
neg_mode,
neg_sample_size,
exclude_positive,
check_false_neg,
relations);
*rv = o;
});
DGL_REGISTER_GLOBAL("sampling._CAPI_GetEdgeSubgraphHead")
DGL_REGISTER_GLOBAL("sampling._CAPI_FetchUniformEdgeSample")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
SubgraphRef g = args[0];
auto gptr = std::dynamic_pointer_cast<NegSubgraph>(g.sptr());
*rv = gptr->head_nid;
UniformEdgeSampler sampler = args[0];
sampler->Fetch(rv);
});
DGL_REGISTER_GLOBAL("sampling._CAPI_GetEdgeSubgraphTail")
// Edge sampler that draws positive edges according to per-edge weights and,
// optionally, corrupts them into negative edges using per-node weights.
// Edges are drawn WITH replacement across batches, so this sampler can
// generate batches indefinitely (unlike the uniform sampler, which stops
// after one pass over the seed edges).
template<typename ValueType>
class WeightedEdgeSamplerObject: public EdgeSamplerObject {
 public:
  // edge_weight: one sampling weight per seed edge (indexed in the same
  //   order as seed_edges); copied into an ArrayHeap for weighted draws.
  // node_weight: one weight per node used for negative sampling; an empty
  //   array means "use uniform node sampling".
  explicit WeightedEdgeSamplerObject(const GraphPtr gptr,
                                     IdArray seed_edges,
                                     NDArray edge_weight,
                                     NDArray node_weight,
                                     const int64_t batch_size,
                                     const int64_t num_workers,
                                     const std::string neg_mode,
                                     const int64_t neg_sample_size,
                                     const bool exclude_positive,
                                     const bool check_false_neg,
                                     IdArray relations)
    : EdgeSamplerObject(gptr,
                        seed_edges,
                        batch_size,
                        num_workers,
                        neg_mode,
                        neg_sample_size,
                        exclude_positive,
                        check_false_neg,
                        relations) {
    // Copy the edge weights into the heap structure used by Sample().
    const size_t num_edges = edge_weight->shape[0];
    const ValueType *edge_prob = static_cast<const ValueType*>(edge_weight->data);
    std::vector<ValueType> eprob(num_edges);
    for (size_t i = 0; i < num_edges; ++i) {
      eprob[i] = edge_prob[i];
    }
    edge_selector_ = std::make_shared<ArrayHeap<ValueType>>(eprob);
    // An empty node_weight array selects uniform negative-node sampling
    // (node_selector_ stays null and randomSample falls back to RandInt).
    const size_t num_nodes = node_weight->shape[0];
    if (num_nodes == 0) {
      node_selector_ = nullptr;
    } else {
      const ValueType *node_prob = static_cast<const ValueType*>(node_weight->data);
      std::vector<ValueType> nprob(num_nodes);
      for (size_t i = 0; i < num_nodes; ++i) {
        nprob[i] = node_prob[i];
      }
      node_selector_ = std::make_shared<ArrayHeap<ValueType>>(nprob);
    }
    // TODO(song): Tricky thing here to make sure gptr_ has coo cache
    gptr_->FindEdge(0);
  }
  ~WeightedEdgeSamplerObject() {
  }

  // Generate num_workers_ batches in parallel: each batch draws batch_size_
  // seed edges by weight, builds the positive edge subgraph and (when
  // neg_mode_ is non-empty) the matching negative subgraph. The result list
  // is [pos_0 .. pos_{k-1}] followed by [neg_0 .. neg_{k-1}] if negatives
  // were requested.
  void Fetch(DGLRetValue* rv) {
    // generate subgraphs.
    std::vector<SubgraphRef> positive_subgs(num_workers_);
    std::vector<SubgraphRef> negative_subgs(num_workers_);
#pragma omp parallel for
    for (int i = 0; i < num_workers_; i++) {
      const dgl_id_t *seed_edge_ids = static_cast<const dgl_id_t *>(seed_edges_->data);
      std::vector<int64_t> edge_ids(batch_size_);
      // NOTE(review): the inner loop re-declares `i`, shadowing the worker
      // index; legal C++ (the outer `i` is used again after the loop) but
      // easy to misread.
      for (int i = 0; i < batch_size_; ++i) {
        // Sample() returns an index into the weight array; map it back to
        // the actual edge id.
        int64_t edge_id = edge_selector_->Sample();
        edge_ids[i] = seed_edge_ids[edge_id];
      }
      auto worker_seeds = aten::VecToIdArray(edge_ids, seed_edges_->dtype.bits);
      EdgeArray arr = gptr_->FindEdges(worker_seeds);
      const dgl_id_t *src_ids = static_cast<const dgl_id_t *>(arr.src->data);
      const dgl_id_t *dst_ids = static_cast<const dgl_id_t *>(arr.dst->data);
      std::vector<dgl_id_t> src_vec(src_ids, src_ids + batch_size_);
      std::vector<dgl_id_t> dst_vec(dst_ids, dst_ids + batch_size_);
      // TODO(zhengda) what if there are duplicates in the src and dst vectors.
      Subgraph subg = gptr_->EdgeSubgraph(worker_seeds, false);
      positive_subgs[i] = ConvertRef(subg);
      // For PBG negative sampling, we accept "PBG-head" for corrupting head
      // nodes and "PBG-tail" for corrupting tail nodes.
      if (neg_mode_.substr(0, 3) == "PBG") {
        // substr(4) strips the "PBG-" prefix, leaving "head"/"tail".
        NegSubgraph neg_subg = genPBGNegEdgeSubgraph(subg, neg_mode_.substr(4),
                                                     neg_sample_size_,
                                                     exclude_positive_,
                                                     check_false_neg_);
        negative_subgs[i] = ConvertRef(neg_subg);
      } else if (neg_mode_.size() > 0) {
        NegSubgraph neg_subg = genNegEdgeSubgraph(subg, neg_mode_,
                                                  neg_sample_size_,
                                                  exclude_positive_,
                                                  check_false_neg_);
        negative_subgs[i] = ConvertRef(neg_subg);
      }
    }
    if (neg_mode_.size() > 0) {
      positive_subgs.insert(positive_subgs.end(), negative_subgs.begin(), negative_subgs.end());
    }
    *rv = List<SubgraphRef>(positive_subgs);
  }

  DGL_DECLARE_OBJECT_TYPE_INFO(WeightedEdgeSamplerObject<ValueType>, Object);

 private:
  // Weighted (or uniform when node_selector_ is null) sampling of `num`
  // distinct node ids from [0, set_size), appended to *out.
  void randomSample(size_t set_size, size_t num, std::vector<size_t>* out) {
    if (num < set_size) {
      // Rejection-style sampling without replacement via a set.
      std::unordered_set<size_t> sampled_idxs;
      while (sampled_idxs.size() < num) {
        if (node_selector_ == nullptr) {
          sampled_idxs.insert(RandomEngine::ThreadLocal()->RandInt(set_size));
        } else {
          size_t id = node_selector_->Sample();
          sampled_idxs.insert(id);
        }
      }
      out->insert(out->end(), sampled_idxs.begin(), sampled_idxs.end());
    } else {
      // If we need to sample all elements in the set, we don't need to
      // generate random numbers.
      for (size_t i = 0; i < set_size; i++)
        out->push_back(i);
    }
  }

  // Same as above, but ids in `exclude` are never returned.
  void randomSample(size_t set_size, size_t num, const std::vector<size_t> &exclude,
                    std::vector<size_t>* out) {
    // Map value 0 marks an excluded id, 1 marks a freshly sampled id;
    // insert() leaves existing (excluded) entries untouched.
    std::unordered_map<size_t, int> sampled_idxs;
    for (auto v : exclude) {
      sampled_idxs.insert(std::pair<size_t, int>(v, 0));
    }
    if (num + exclude.size() < set_size) {
      while (sampled_idxs.size() < num + exclude.size()) {
        size_t rand;
        if (node_selector_ == nullptr) {
          rand = RandomEngine::ThreadLocal()->RandInt(set_size);
        } else {
          rand = node_selector_->Sample();
        }
        sampled_idxs.insert(std::pair<size_t, int>(rand, 1));
      }
      for (auto it = sampled_idxs.begin(); it != sampled_idxs.end(); it++) {
        if (it->second) {
          out->push_back(it->first);
        }
      }
    } else {
      // If we need to sample all elements in the set, we don't need to
      // generate random numbers.
      for (size_t i = 0; i < set_size; i++) {
        // If the element doesn't exist in exclude.
        if (sampled_idxs.find(i) == sampled_idxs.end()) {
          out->push_back(i);
        }
      }
    }
  }

 private:
  // Cumulative-weight heaps backing weighted draws of edges and nodes.
  std::shared_ptr<ArrayHeap<ValueType>> edge_selector_;
  std::shared_ptr<ArrayHeap<ValueType>> node_selector_;  // null => uniform
};
template class WeightedEdgeSamplerObject<float>;
// FFI reference (handle) type for the float-weighted sampler; mirrors
// UniformEdgeSampler but wraps WeightedEdgeSamplerObject<float>.
class FloatWeightedEdgeSampler: public ObjectRef {
 public:
  FloatWeightedEdgeSampler() {}
  explicit FloatWeightedEdgeSampler(std::shared_ptr<runtime::Object> obj): ObjectRef(obj) {}
  // Unchecked downcast; callers must guarantee the wrapped object's type.
  WeightedEdgeSamplerObject<float>* operator->() const {
    return static_cast<WeightedEdgeSamplerObject<float>*>(obj_.get());
  }
  // Checked access: CHECK-fails if the wrapped object has another type.
  std::shared_ptr<WeightedEdgeSamplerObject<float>> sptr() const {
    return CHECK_NOTNULL(std::dynamic_pointer_cast<WeightedEdgeSamplerObject<float>>(obj_));
  }
  // True when the handle actually holds an object.
  operator bool() const { return this->defined(); }
  using ContainerType = WeightedEdgeSamplerObject<float>;
};
DGL_REGISTER_GLOBAL("sampling._CAPI_CreateWeightedEdgeSampler")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
SubgraphRef g = args[0];
auto gptr = std::dynamic_pointer_cast<NegSubgraph>(g.sptr());
*rv = gptr->tail_nid;
// arguments
GraphRef g = args[0];
IdArray seed_edges = args[1];
NDArray edge_weight = args[2];
NDArray node_weight = args[3];
const int64_t batch_size = args[4];
const int64_t max_num_workers = args[5];
const std::string neg_mode = args[6];
const int64_t neg_sample_size = args[7];
const bool exclude_positive = args[8];
const bool check_false_neg = args[9];
IdArray relations = args[10];
auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(gptr) << "sampling isn't implemented in mutable graph";
CHECK(aten::IsValidIdArray(seed_edges));
CHECK(edge_weight->dtype.code == kDLFloat) << "edge_weight should be FloatType";
CHECK(edge_weight->dtype.bits == 32) << "WeightedEdgeSampler only support float weight";
if (node_weight->shape[0] > 0) {
CHECK(node_weight->dtype.code == kDLFloat) << "node_weight should be FloatType";
CHECK(node_weight->dtype.bits == 32) << "WeightedEdgeSampler only support float weight";
}
BuildCoo(*gptr);
const int64_t num_seeds = seed_edges->shape[0];
const int64_t num_workers = std::min(max_num_workers,
(num_seeds + batch_size - 1) / batch_size);
auto o = std::make_shared<WeightedEdgeSamplerObject<float>>(gptr,
seed_edges,
edge_weight,
node_weight,
batch_size,
num_workers,
neg_mode,
neg_sample_size,
exclude_positive,
check_false_neg,
relations);
*rv = o;
});
DGL_REGISTER_GLOBAL("sampling._CAPI_FetchWeightedEdgeSample")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
FloatWeightedEdgeSampler sampler = args[0];
sampler->Fetch(rv);
});
} // namespace dgl
......@@ -237,6 +237,7 @@ def check_head_tail(g):
def check_negative_sampler(mode, exclude_positive, neg_size):
g = generate_rand_graph(100)
num_edges = g.number_of_edges()
etype = np.random.randint(0, 10, size=g.number_of_edges(), dtype=np.int64)
g.edata['etype'] = F.copy_to(F.tensor(etype), F.cpu())
......@@ -249,7 +250,10 @@ def check_negative_sampler(mode, exclude_positive, neg_size):
EdgeSampler = getattr(dgl.contrib.sampling, 'EdgeSampler')
# Test the homogeneous graph.
for pos_edges, neg_edges in EdgeSampler(g, 50,
total_samples = 0
batch_size = 50
max_samples = num_edges
for pos_edges, neg_edges in EdgeSampler(g, batch_size,
negative_mode=mode,
neg_sample_size=neg_size,
exclude_positive=exclude_positive,
......@@ -284,8 +288,13 @@ def check_negative_sampler(mode, exclude_positive, neg_size):
else:
assert F.array_equal(g.has_edges_between(neg_src, neg_dst), exist)
total_samples += batch_size
if (total_samples >= max_samples):
break
# Test the knowledge graph.
for _, neg_edges in EdgeSampler(g, 50,
total_samples = 0
for _, neg_edges in EdgeSampler(g, batch_size,
negative_mode=mode,
neg_sample_size=neg_size,
exclude_positive=exclude_positive,
......@@ -304,12 +313,223 @@ def check_negative_sampler(mode, exclude_positive, neg_size):
etype = g.edata['etype'][eid]
exist = neg_edges.edata['etype'][i] == etype
assert F.asnumpy(exists[i]) == F.asnumpy(exist)
total_samples += batch_size
if (total_samples >= max_samples):
break
def check_weighted_negative_sampler(mode, exclude_positive, neg_size):
    """Exercise EdgeSampler with edge_weight / node_weight arguments.

    Two phases:
      1. Correctness: with uniform (all-ones) weights the sampler must behave
         like the unweighted sampler — positive edges map back to parent
         edges, negative edges are proper corruptions, and the 'false_neg'
         mask is consistent with the graph (and with relations for the KG case).
      2. Rate: with strongly skewed weights, the empirical sampling
         frequencies of edges/nodes must track the weights.

    Parameters
    ----------
    mode : str
        Negative sampling mode, e.g. 'head' or 'PBG-head'; 'head' in the
        string selects which endpoint is corrupted in the rate bookkeeping.
    exclude_positive : bool
        Whether the sampler is asked to exclude true edges from negatives.
    neg_size : int
        Number of negative samples per positive edge.
    """
    g = generate_rand_graph(100)
    num_edges = g.number_of_edges()
    num_nodes = g.number_of_nodes()
    # Uniform weights: weighted sampling should match unweighted behavior.
    edge_weight = F.copy_to(F.tensor(np.full((num_edges,), 1, dtype=np.float32)), F.cpu())
    node_weight = F.copy_to(F.tensor(np.full((num_nodes,), 1, dtype=np.float32)), F.cpu())
    etype = np.random.randint(0, 10, size=num_edges, dtype=np.int64)
    g.edata['etype'] = F.copy_to(F.tensor(etype), F.cpu())

    # Build (dst, eid) -> src lookup so each negative edge can be traced
    # back to the positive edge it corrupts.
    pos_gsrc, pos_gdst, pos_geid = g.all_edges(form='all', order='eid')
    pos_map = {}
    for i in range(len(pos_geid)):
        pos_d = int(F.asnumpy(pos_gdst[i]))
        pos_e = int(F.asnumpy(pos_geid[i]))
        pos_map[(pos_d, pos_e)] = int(F.asnumpy(pos_gsrc[i]))

    EdgeSampler = getattr(dgl.contrib.sampling, 'EdgeSampler')

    # Correctness check
    # Test the homogeneous graph.
    batch_size = 50
    total_samples = 0
    max_samples = num_edges  # cap iteration: sample roughly one epoch worth
    for pos_edges, neg_edges in EdgeSampler(g, batch_size,
                                            edge_weight=edge_weight,
                                            negative_mode=mode,
                                            neg_sample_size=neg_size,
                                            exclude_positive=exclude_positive,
                                            return_false_neg=True):
        # Positive subgraph edges must map back to real parent-graph edges.
        pos_lsrc, pos_ldst, pos_leid = pos_edges.all_edges(form='all', order='eid')
        assert_array_equal(F.asnumpy(pos_edges.parent_eid[pos_leid]),
                           F.asnumpy(g.edge_ids(pos_edges.parent_nid[pos_lsrc],
                                                pos_edges.parent_nid[pos_ldst])))

        neg_lsrc, neg_ldst, neg_leid = neg_edges.all_edges(form='all', order='eid')
        neg_src = neg_edges.parent_nid[neg_lsrc]
        neg_dst = neg_edges.parent_nid[neg_ldst]
        neg_eid = neg_edges.parent_eid[neg_leid]
        for i in range(len(neg_eid)):
            neg_d = int(F.asnumpy(neg_dst[i]))
            neg_e = int(F.asnumpy(neg_eid[i]))
            # Each negative edge must correspond to some positive edge.
            assert (neg_d, neg_e) in pos_map
            if exclude_positive:
                # The corrupted head must differ from the true head.
                assert int(F.asnumpy(neg_src[i])) != pos_map[(neg_d, neg_e)]

        check_head_tail(neg_edges)
        # Uncorrupted endpoints must be the same set in pos and neg subgraphs.
        pos_tails = pos_edges.parent_nid[pos_edges.tail_nid]
        neg_tails = neg_edges.parent_nid[neg_edges.tail_nid]
        pos_tails = np.sort(F.asnumpy(pos_tails))
        neg_tails = np.sort(F.asnumpy(neg_tails))
        np.testing.assert_equal(pos_tails, neg_tails)

        exist = neg_edges.edata['false_neg']
        if exclude_positive:
            # With exclusion on, no negative may be a false negative.
            assert np.sum(F.asnumpy(exist) == 0) == len(exist)
        else:
            # Otherwise the mask must match actual edge existence.
            assert F.array_equal(g.has_edges_between(neg_src, neg_dst), exist)
        total_samples += batch_size
        if (total_samples >= max_samples):
            break

    # Test the knowledge graph with edge weight provied.
    total_samples = 0
    for pos_edges, neg_edges in EdgeSampler(g, batch_size,
                                            edge_weight=edge_weight,
                                            negative_mode=mode,
                                            neg_sample_size=neg_size,
                                            exclude_positive=exclude_positive,
                                            relations=g.edata['etype'],
                                            return_false_neg=True):
        neg_lsrc, neg_ldst, neg_leid = neg_edges.all_edges(form='all', order='eid')
        neg_src = neg_edges.parent_nid[neg_lsrc]
        neg_dst = neg_edges.parent_nid[neg_ldst]
        neg_eid = neg_edges.parent_eid[neg_leid]
        exists = neg_edges.edata['false_neg']
        neg_edges.edata['etype'] = g.edata['etype'][neg_eid]
        for i in range(len(neg_eid)):
            u, v = F.asnumpy(neg_src[i]), F.asnumpy(neg_dst[i])
            if g.has_edge_between(u, v):
                # With relations, a negative is a false negative only when the
                # existing (u, v) edge also carries the same relation type.
                eid = g.edge_id(u, v)
                etype = g.edata['etype'][eid]
                exist = neg_edges.edata['etype'][i] == etype
                assert F.asnumpy(exists[i]) == F.asnumpy(exist)
        total_samples += batch_size
        if (total_samples >= max_samples):
            break

    # Test the knowledge graph with edge/node weight provied.
    total_samples = 0
    for pos_edges, neg_edges in EdgeSampler(g, batch_size,
                                            edge_weight=edge_weight,
                                            node_weight=node_weight,
                                            negative_mode=mode,
                                            neg_sample_size=neg_size,
                                            exclude_positive=exclude_positive,
                                            relations=g.edata['etype'],
                                            return_false_neg=True):
        neg_lsrc, neg_ldst, neg_leid = neg_edges.all_edges(form='all', order='eid')
        neg_src = neg_edges.parent_nid[neg_lsrc]
        neg_dst = neg_edges.parent_nid[neg_ldst]
        neg_eid = neg_edges.parent_eid[neg_leid]
        exists = neg_edges.edata['false_neg']
        neg_edges.edata['etype'] = g.edata['etype'][neg_eid]
        for i in range(len(neg_eid)):
            u, v = F.asnumpy(neg_src[i]), F.asnumpy(neg_dst[i])
            if g.has_edge_between(u, v):
                eid = g.edge_id(u, v)
                etype = g.edata['etype'][eid]
                exist = neg_edges.edata['etype'][i] == etype
                assert F.asnumpy(exists[i]) == F.asnumpy(exist)
        total_samples += batch_size
        if (total_samples >= max_samples):
            break

    # Check Rate
    dgl.random.seed(0)
    g = generate_rand_graph(1000)
    num_edges = g.number_of_edges()
    num_nodes = g.number_of_nodes()
    # Edge 0 gets weight == sum of all weights, i.e. half the total mass,
    # so it should be sampled ~50% of the time (asserted below).
    edge_weight = F.copy_to(F.tensor(np.full((num_edges,), 1, dtype=np.float32)), F.cpu())
    edge_weight[0] = F.sum(edge_weight, dim=0)
    # Last node gets weight sum/200, making it heavier than the uniform nodes.
    node_weight = F.copy_to(F.tensor(np.full((num_nodes,), 1, dtype=np.float32)), F.cpu())
    node_weight[-1] = F.sum(node_weight, dim=0) / 200
    etype = np.random.randint(0, 20, size=num_edges, dtype=np.int64)
    g.edata['etype'] = F.copy_to(F.tensor(etype), F.cpu())

    # Test w/o node weight.
    max_samples = num_edges / 5
    # Test the knowledge graph with edge weight provied.
    total_samples = 0
    edge_sampled = np.full((num_edges,), 0, dtype=np.int32)
    node_sampled = np.full((num_nodes,), 0, dtype=np.int32)
    for pos_edges, neg_edges in EdgeSampler(g, batch_size,
                                            edge_weight=edge_weight,
                                            negative_mode=mode,
                                            neg_sample_size=neg_size,
                                            exclude_positive=False,
                                            relations=g.edata['etype'],
                                            return_false_neg=True):
        _, _, pos_leid = pos_edges.all_edges(form='all', order='eid')
        neg_lsrc, neg_ldst, _ = neg_edges.all_edges(form='all', order='eid')
        # Count the corrupted endpoint: heads for '*head' modes, else tails.
        if 'head' in mode:
            neg_src = neg_edges.parent_nid[neg_lsrc]
            np.add.at(node_sampled, F.asnumpy(neg_src), 1)
        else:
            neg_dst = neg_edges.parent_nid[neg_ldst]
            np.add.at(node_sampled, F.asnumpy(neg_dst), 1)
        np.add.at(edge_sampled, F.asnumpy(pos_edges.parent_eid[pos_leid]), 1)
        total_samples += batch_size
        if (total_samples >= max_samples):
            break

    # Check rate here
    # Edge 0 carries half the weight mass -> ~0.5 sample rate; the tail half
    # of the edges carries ~a quarter of the mass -> ~0.25.
    edge_rate_0 = edge_sampled[0] / edge_sampled.sum()
    edge_tail_half_cnt = edge_sampled[edge_sampled.shape[0] // 2:-1].sum()
    edge_rate_tail_half = edge_tail_half_cnt / edge_sampled.sum()
    assert np.allclose(edge_rate_0, 0.5, atol=0.05)
    assert np.allclose(edge_rate_tail_half, 0.25, atol=0.05)

    # Without node_weight, negative nodes should be roughly uniform.
    node_rate_0 = node_sampled[0] / node_sampled.sum()
    node_tail_half_cnt = node_sampled[node_sampled.shape[0] // 2:-1].sum()
    node_rate_tail_half = node_tail_half_cnt / node_sampled.sum()
    assert node_rate_0 < 0.02
    assert np.allclose(node_rate_tail_half, 0.5, atol=0.02)

    # Test the knowledge graph with edge/node weight provied.
    total_samples = 0
    edge_sampled = np.full((num_edges,), 0, dtype=np.int32)
    node_sampled = np.full((num_nodes,), 0, dtype=np.int32)
    for pos_edges, neg_edges in EdgeSampler(g, batch_size,
                                            edge_weight=edge_weight,
                                            node_weight=node_weight,
                                            negative_mode=mode,
                                            neg_sample_size=neg_size,
                                            exclude_positive=False,
                                            relations=g.edata['etype'],
                                            return_false_neg=True):
        _, _, pos_leid = pos_edges.all_edges(form='all', order='eid')
        neg_lsrc, neg_ldst, _ = neg_edges.all_edges(form='all', order='eid')
        if 'head' in mode:
            neg_src = neg_edges.parent_nid[neg_lsrc]
            np.add.at(node_sampled, F.asnumpy(neg_src), 1)
        else:
            neg_dst = neg_edges.parent_nid[neg_ldst]
            np.add.at(node_sampled, F.asnumpy(neg_dst), 1)
        np.add.at(edge_sampled, F.asnumpy(pos_edges.parent_eid[pos_leid]), 1)
        total_samples += batch_size
        if (total_samples >= max_samples):
            break

    # Check rate here
    edge_rate_0 = edge_sampled[0] / edge_sampled.sum()
    edge_tail_half_cnt = edge_sampled[edge_sampled.shape[0] // 2:-1].sum()
    edge_rate_tail_half = edge_tail_half_cnt / edge_sampled.sum()
    assert np.allclose(edge_rate_0, 0.5, atol=0.05)
    assert np.allclose(edge_rate_tail_half, 0.25, atol=0.05)

    # The last node's weight is 5x a uniform node's, so its rate should be
    # ~5x the average rate of ordinary nodes, and two disjoint groups of
    # ordinary nodes should have matching average rates.
    node_rate = node_sampled[-1] / node_sampled.sum()
    node_rate_a = np.average(node_sampled[:50]) / node_sampled.sum()
    node_rate_b = np.average(node_sampled[50:100]) / node_sampled.sum()
    # As neg sampling does not contain duplicate nodes,
    # this test takes some acceptable variation on the sample rate.
    assert np.allclose(node_rate, node_rate_a * 5, atol=0.002)
    assert np.allclose(node_rate_a, node_rate_b, atol=0.0002)
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="Core dump")
def test_negative_sampler():
    """Run the uniform and weighted negative-sampler checks over a few
    (mode, exclude_positive, neg_size) configurations."""
    configs = [('PBG-head', False, 10),
               ('head', True, 10),
               ('head', False, 10)]
    for mode, exclude_positive, neg_size in configs:
        check_negative_sampler(mode, exclude_positive, neg_size)
    for mode, exclude_positive, neg_size in configs:
        check_weighted_negative_sampler(mode, exclude_positive, neg_size)
    #disable this check for now. It might take too long time.
    #check_negative_sampler('head', False, 100)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment