Commit 61b78e6e authored by xiang song (charlie.song), committed by Da Zheng

[Bug Fix] edge sample hotfix (#1152)

* hot fix

* Fix docs

* Fix ArrayHeap float overflow bug

* Fix

* Clean some dead code

* Fix

* Fix

* Add some comments

* run test
parent 6731ea3a
@@ -542,6 +542,10 @@ class EdgeSampler(object):
    edges and reset the replacement state. If it is set to false, the sampler will only
    generate num_edges/batch_size samples.
+   Note: If node_weight is extremely imbalanced, the sampler will take much longer
+   to return a minibatch, because the sampled negative nodes must not be duplicated
+   for one corrupted positive edge.

    Parameters
    ----------
    g : DGLGraph
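The note added in this hunk can be illustrated with a rough NumPy sketch (this is not the DGL C++ sampler, just the dedup requirement in isolation): when a single node dominates node_weight, collecting `neg_size` distinct negatives for one corrupted edge needs many more draws, because the dominant node keeps being drawn again and rejected.

```python
import numpy as np

def draws_needed(node_weight, neg_size, rng):
    # Keep drawing weighted nodes until `neg_size` *distinct* negatives are found,
    # counting how many raw draws that takes.
    prob = node_weight / node_weight.sum()
    chosen, draws = set(), 0
    while len(chosen) < neg_size:
        chosen.add(int(rng.choice(len(prob), p=prob)))
        draws += 1
    return draws

rng = np.random.default_rng(0)
balanced = np.ones(1000)
skewed = np.ones(1000)
skewed[0] = 1e6                          # one node carries almost all the weight
print(draws_needed(balanced, 10, rng))   # close to 10 draws
print(draws_needed(skewed, 10, rng))     # typically thousands of draws
```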
@@ -737,11 +741,9 @@ class EdgeSampler(object):
    def __iter__(self):
        it = SamplerIter(self)
        if self._is_uniform:
-           subgs = _CAPI_ResetUniformEdgeSample(
-               self._sampler)
+           _CAPI_ResetUniformEdgeSample(self._sampler)
        else:
-           subgs = _CAPI_ResetWeightedEdgeSample(
-               self._sampler)
+           _CAPI_ResetWeightedEdgeSample(self._sampler)
        if self._num_prefetch:
            return self._prefetching_wrapper_class(it, self._num_prefetch)
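A usage-level sketch of what this hunk changes (assuming `g` and `edge_weight` are built as in check_positive_edge_sampler() further down): the reset CAPI is now called purely for its side effect, so constructing a new iterator over the same EdgeSampler starts from a reset sampler state.

```python
import dgl.contrib.sampling as sampling

# Sketch only: `g` is a DGLGraph and `edge_weight` a per-edge float tensor.
sampler = sampling.EdgeSampler(g, 128, reset=False, edge_weight=edge_weight)
first_pass = [pos for pos in sampler]    # iter() resets the C-side state, then yields batches
second_pass = [pos for pos in sampler]   # a fresh iterator begins from a reset state again
```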
...
@@ -51,10 +51,13 @@ class ArrayHeap {
   */
  void Delete(size_t index) {
    size_t i = index + limit_;
-   ValueType w = heap_[i];
-   for (int j = bit_len_; j >= 0; --j) {
-     heap_[i] -= w;
-     i = i >> 1;
+   heap_[i] = 0;
+   i /= 2;
+   for (int j = bit_len_-1; j >= 0; --j) {
+     // Using heap_[i] = heap_[i] - w will lose some precision in float.
+     // Using addition to re-calculate the weight layer by layer.
+     heap_[i] = heap_[i << 1] + heap_[(i << 1) + 1];
+     i /= 2;
    }
  }
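The new comments state the motivation; a minimal NumPy illustration of the same float32 failure mode (not the ArrayHeap code itself) is below. When a tiny leaf weight sits under a much larger sibling, the parent's float32 sum absorbs it, and deleting the large leaf by subtraction then reports zero remaining weight, while re-summing the two children gives the correct value.

```python
import numpy as np

# Two leaf weights under one parent; float32 throughout, as in ArrayHeap<float>.
leaves = np.array([1e8, 0.5], dtype=np.float32)
parent = leaves[0] + leaves[1]                  # 0.5 is absorbed by rounding: parent == 1e8

# Old scheme: delete leaf 0 by subtracting its weight from every ancestor.
parent_by_subtraction = parent - leaves[0]      # -> 0.0, the remaining 0.5 "vanishes"

# Fixed scheme: zero the leaf, then rebuild the parent from its children.
leaves[0] = 0.0
parent_recomputed = leaves[0] + leaves[1]       # -> 0.5, as expected

print(parent_by_subtraction, parent_recomputed)  # 0.0 0.5
```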
@@ -1480,12 +1483,15 @@ public:
                 sizeof(dgl_id_t) * start);
    } else {
      std::vector<dgl_id_t> seeds;
+     const dgl_id_t *seed_edge_ids = static_cast<const dgl_id_t *>(seed_edges_->data);
      // sampling of each edge is a standalone event
      for (int64_t i = 0; i < num_edges; ++i) {
-       seeds.push_back(RandomEngine::ThreadLocal()->RandInt(num_seeds_));
+       int64_t seed = static_cast<const int64_t>(
+           RandomEngine::ThreadLocal()->RandInt(num_seeds_));
+       seeds.push_back(seed_edge_ids[seed]);
      }
-     worker_seeds = aten::VecToIdArray(seeds);
+     worker_seeds = aten::VecToIdArray(seeds, seed_edges_->dtype.bits);
    }
    EdgeArray arr = gptr_->FindEdges(worker_seeds);
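The uniform branch above now uses the drawn index only to look up an entry of seed_edge_ids, instead of pushing the raw index itself, and it forwards the seed array's dtype bits to VecToIdArray. A small NumPy sketch of that indirection (hypothetical arrays, not DGL internals):

```python
import numpy as np

# The caller restricts sampling to these parent-graph edge ids.
seed_edge_ids = np.array([3, 17, 42, 256], dtype=np.int64)
rng = np.random.default_rng(0)

positions = rng.integers(0, len(seed_edge_ids), size=5)  # indices into the seed array
edge_ids = seed_edge_ids[positions]                       # actual edge ids to fetch
# Returning `positions` directly (the old behaviour) is only correct when
# seed_edge_ids happens to be exactly 0..num_seeds-1.
print(positions, edge_ids)
```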
@@ -1674,7 +1680,6 @@ class WeightedEdgeSamplerObject: public EdgeSamplerObject {
    curr_batch_id_ = 0;
    // handle int64 overflow here
    max_batch_id_ = (num_edges + batch_size - 1) / batch_size;
    // TODO(song): Tricky thing here to make sure gptr_ has coo cache
    gptr_->FindEdge(0);
  }
@@ -1697,9 +1702,12 @@ class WeightedEdgeSamplerObject: public EdgeSamplerObject {
      size_t n = batch_size_;
      size_t num_ids = 0;
      #pragma omp critical
+     {
        num_ids = edge_selector_->SampleWithoutReplacement(n, &edge_ids);
-     while (edge_ids.size() > num_ids) {
-       edge_ids.pop_back();
+     }
+     edge_ids.resize(num_ids);
+     for (size_t i = 0; i < num_ids; ++i) {
+       edge_ids[i] = seed_edge_ids[edge_ids[i]];
      }
    } else {
      // sampling of each edge is a standalone event
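In the weighted branch above, the selector may hand back fewer than batch_size ids once the pool of seed edges runs low, so only the SampleWithoutReplacement call stays inside the OpenMP critical section, the vector is shrunk to num_ids, and the surviving positions are translated into parent edge ids. A rough NumPy sketch of that shrink-then-map step (hypothetical helper, not the DGL selector API):

```python
import numpy as np

def sample_without_replacement(pool, n, rng):
    # Stand-in for the C++ edge selector: draw up to n entries from `pool`
    # without replacement and return them together with the shrunken pool.
    k = min(n, len(pool))
    idx = rng.choice(len(pool), size=k, replace=False)
    return pool[idx], np.delete(pool, idx)

seed_edge_ids = np.arange(100, 110, dtype=np.int64)   # 10 candidate parent edge ids
positions = np.arange(len(seed_edge_ids))             # positions the selector draws from
rng = np.random.default_rng(0)

batch1, positions = sample_without_replacement(positions, 8, rng)  # full batch of 8
batch2, positions = sample_without_replacement(positions, 8, rng)  # only 2 positions left
edge_ids = seed_edge_ids[batch2]    # shrink to what was returned, then map to edge ids
print(len(batch1), len(batch2), edge_ids)
```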
@@ -1708,6 +1716,7 @@ class WeightedEdgeSamplerObject: public EdgeSamplerObject {
        edge_ids[i] = seed_edge_ids[edge_id];
      }
    }
    auto worker_seeds = aten::VecToIdArray(edge_ids, seed_edges_->dtype.bits);
    EdgeArray arr = gptr_->FindEdges(worker_seeds);
@@ -1716,7 +1725,6 @@ class WeightedEdgeSamplerObject: public EdgeSamplerObject {
      std::vector<dgl_id_t> src_vec(src_ids, src_ids + batch_size_);
      std::vector<dgl_id_t> dst_vec(dst_ids, dst_ids + batch_size_);
      // TODO(zhengda) what if there are duplicates in the src and dst vectors.
      Subgraph subg = gptr_->EdgeSubgraph(worker_seeds, false);
      positive_subgs[i] = ConvertRef(subg);
      // For PBG negative sampling, we accept "PBG-head" for corrupting head
...
@@ -665,6 +665,38 @@ def check_weighted_negative_sampler(mode, exclude_positive, neg_size):
    assert np.allclose(node_rate, node_rate_a * 5, atol=0.002)
    assert np.allclose(node_rate_a, node_rate_b, atol=0.0002)

+def check_positive_edge_sampler():
+    g = generate_rand_graph(1000)
+    num_edges = g.number_of_edges()
+    edge_weight = F.copy_to(F.tensor(np.full((num_edges,), 1, dtype=np.float32)), F.cpu())
+    edge_weight[num_edges-1] = num_edges ** 2
+    EdgeSampler = getattr(dgl.contrib.sampling, 'EdgeSampler')
+
+    # Correctness check
+    # Test the homogeneous graph.
+    batch_size = 128
+    edge_sampled = np.full((num_edges,), 0, dtype=np.int32)
+    for pos_edges in EdgeSampler(g, batch_size,
+                                 reset=False,
+                                 edge_weight=edge_weight):
+        _, _, pos_leid = pos_edges.all_edges(form='all', order='eid')
+        np.add.at(edge_sampled, F.asnumpy(pos_edges.parent_eid[pos_leid]), 1)
+    truth = np.full((num_edges,), 1, dtype=np.int32)
+    edge_sampled = edge_sampled[:num_edges]
+    assert np.array_equal(truth, edge_sampled)
+
+    edge_sampled = np.full((num_edges,), 0, dtype=np.int32)
+    for pos_edges in EdgeSampler(g, batch_size,
+                                 reset=False,
+                                 shuffle=True,
+                                 edge_weight=edge_weight):
+        _, _, pos_leid = pos_edges.all_edges(form='all', order='eid')
+        np.add.at(edge_sampled, F.asnumpy(pos_edges.parent_eid[pos_leid]), 1)
+    truth = np.full((num_edges,), 1, dtype=np.int32)
+    edge_sampled = edge_sampled[:num_edges]
+    assert np.array_equal(truth, edge_sampled)
+
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support item assignment")
def test_negative_sampler():
@@ -674,6 +706,7 @@ def test_negative_sampler():
    check_weighted_negative_sampler('PBG-head', False, 10)
    check_weighted_negative_sampler('head', True, 10)
    check_weighted_negative_sampler('head', False, 10)
+   check_positive_edge_sampler()
    #disable this check for now. It might take too long time.
    #check_negative_sampler('head', False, 100)
...