Unverified Commit 77f4287a authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bugfix] Fixes the redundancy parameter being used wrong in global negative sampling (#3657)

* oops

* test
parent 48cbea72
......@@ -27,7 +27,7 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
double redundancy) {
const int64_t num_row = csr.num_rows;
const int64_t num_col = csr.num_cols;
const int64_t num_actual_samples = static_cast<int64_t>(num_samples * redundancy);
const int64_t num_actual_samples = static_cast<int64_t>(num_samples * (1 + redundancy));
IdArray row = Full<IdType>(-1, num_actual_samples, csr.indptr->ctx);
IdArray col = Full<IdType>(-1, num_actual_samples, csr.indptr->ctx);
IdType* row_data = row.Ptr<IdType>();
......
......@@ -140,7 +140,7 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
auto dtype = csr.indptr->dtype;
const int64_t num_row = csr.num_rows;
const int64_t num_col = csr.num_cols;
const int64_t num_actual_samples = static_cast<int64_t>(num_samples * redundancy);
const int64_t num_actual_samples = static_cast<int64_t>(num_samples * (1 + redundancy));
IdArray row = Full<IdType>(-1, num_actual_samples, ctx);
IdArray col = Full<IdType>(-1, num_actual_samples, ctx);
IdArray out_row = IdArray::Empty({num_actual_samples}, dtype, ctx);
......
......@@ -892,10 +892,10 @@ def test_sample_neighbors_exclude_edges_homoG(dtype):
@pytest.mark.parametrize('dtype', ['int32', 'int64'])
def test_global_uniform_negative_sampling(dtype):
g = dgl.graph((np.random.randint(0, 20, (10,)), np.random.randint(0, 20, (10,)))).to(F.ctx())
src, dst = dgl.sampling.global_uniform_negative_sampling(g, 20, False, True)
assert len(src) > 0
assert len(dst) > 0
g = dgl.graph(([], []), num_nodes=1000).to(F.ctx())
src, dst = dgl.sampling.global_uniform_negative_sampling(g, 2000, False, True)
assert len(src) == 2000
assert len(dst) == 2000
g = dgl.graph((np.random.randint(0, 20, (300,)), np.random.randint(0, 20, (300,)))).to(F.ctx())
src, dst = dgl.sampling.global_uniform_negative_sampling(g, 20, False, True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment