Unverified Commit 2aad1c0b authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bugfix] Fix GPU global negative sampling code (#3653)

* fix GPU global negative sampling code

* Update negative_sampling.cu
parent 8d14a739
......@@ -42,8 +42,13 @@ __global__ void _GlobalUniformNegativeSamplingKernel(
while (tx < num_samples) {
for (int i = 0; i < num_trials; ++i) {
uint4 result = curand4(&rng);
IdType u = ((result.x << 32) | result.y) % num_row;
IdType v = ((result.z << 32) | result.w) % num_col;
// Turns out that result.x is always 0 with the above RNG.
uint64_t y_hi = result.y >> 16;
uint64_t y_lo = result.y & 0xFFFF;
uint64_t z = static_cast<uint64_t>(result.z);
uint64_t w = static_cast<uint64_t>(result.w);
int64_t u = static_cast<int64_t>(((y_lo << 32L) | z) % num_row);
int64_t v = static_cast<int64_t>(((y_hi << 32L) | w) % num_col);
if (exclude_self_loops && (u == v))
continue;
......
......@@ -892,6 +892,11 @@ def test_sample_neighbors_exclude_edges_homoG(dtype):
@pytest.mark.parametrize('dtype', ['int32', 'int64'])
def test_global_uniform_negative_sampling(dtype):
g = dgl.graph((np.random.randint(0, 20, (10,)), np.random.randint(0, 20, (10,)))).to(F.ctx())
src, dst = dgl.sampling.global_uniform_negative_sampling(g, 20, False, True)
assert len(src) > 0
assert len(dst) > 0
g = dgl.graph((np.random.randint(0, 20, (300,)), np.random.randint(0, 20, (300,)))).to(F.ctx())
src, dst = dgl.sampling.global_uniform_negative_sampling(g, 20, False, True)
assert not F.asnumpy(g.has_edges_between(src, dst)).any()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment