Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
77f4287a
Unverified
Commit
77f4287a
authored
Jan 17, 2022
by
Quan (Andy) Gan
Committed by
GitHub
Jan 17, 2022
Browse files
[Bugfix] Fixes the redundancy parameter being used wrong in global negative sampling (#3657)
* oops * test
parent
48cbea72
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
6 additions
and
6 deletions
+6
-6
src/array/cpu/negative_sampling.cc
src/array/cpu/negative_sampling.cc
+1
-1
src/array/cuda/negative_sampling.cu
src/array/cuda/negative_sampling.cu
+1
-1
tests/compute/test_sampling.py
tests/compute/test_sampling.py
+4
-4
No files found.
src/array/cpu/negative_sampling.cc
View file @
77f4287a
...
...
@@ -27,7 +27,7 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
double
redundancy
)
{
const
int64_t
num_row
=
csr
.
num_rows
;
const
int64_t
num_col
=
csr
.
num_cols
;
const
int64_t
num_actual_samples
=
static_cast
<
int64_t
>
(
num_samples
*
redundancy
);
const
int64_t
num_actual_samples
=
static_cast
<
int64_t
>
(
num_samples
*
(
1
+
redundancy
)
)
;
IdArray
row
=
Full
<
IdType
>
(
-
1
,
num_actual_samples
,
csr
.
indptr
->
ctx
);
IdArray
col
=
Full
<
IdType
>
(
-
1
,
num_actual_samples
,
csr
.
indptr
->
ctx
);
IdType
*
row_data
=
row
.
Ptr
<
IdType
>
();
...
...
src/array/cuda/negative_sampling.cu
View file @
77f4287a
...
...
@@ -140,7 +140,7 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
auto
dtype
=
csr
.
indptr
->
dtype
;
const
int64_t
num_row
=
csr
.
num_rows
;
const
int64_t
num_col
=
csr
.
num_cols
;
const
int64_t
num_actual_samples
=
static_cast
<
int64_t
>
(
num_samples
*
redundancy
);
const
int64_t
num_actual_samples
=
static_cast
<
int64_t
>
(
num_samples
*
(
1
+
redundancy
)
)
;
IdArray
row
=
Full
<
IdType
>
(
-
1
,
num_actual_samples
,
ctx
);
IdArray
col
=
Full
<
IdType
>
(
-
1
,
num_actual_samples
,
ctx
);
IdArray
out_row
=
IdArray
::
Empty
({
num_actual_samples
},
dtype
,
ctx
);
...
...
tests/compute/test_sampling.py
View file @
77f4287a
...
...
@@ -892,10 +892,10 @@ def test_sample_neighbors_exclude_edges_homoG(dtype):
@
pytest
.
mark
.
parametrize
(
'dtype'
,
[
'int32'
,
'int64'
])
def
test_global_uniform_negative_sampling
(
dtype
):
g
=
dgl
.
graph
((
np
.
random
.
randint
(
0
,
20
,
(
10
,)),
np
.
random
.
randint
(
0
,
20
,
(
10
,)))
).
to
(
F
.
ctx
())
src
,
dst
=
dgl
.
sampling
.
global_uniform_negative_sampling
(
g
,
20
,
False
,
True
)
assert
len
(
src
)
>
0
assert
len
(
dst
)
>
0
g
=
dgl
.
graph
((
[],
[]),
num_nodes
=
1000
).
to
(
F
.
ctx
())
src
,
dst
=
dgl
.
sampling
.
global_uniform_negative_sampling
(
g
,
20
00
,
False
,
True
)
assert
len
(
src
)
==
200
0
assert
len
(
dst
)
==
200
0
g
=
dgl
.
graph
((
np
.
random
.
randint
(
0
,
20
,
(
300
,)),
np
.
random
.
randint
(
0
,
20
,
(
300
,)))).
to
(
F
.
ctx
())
src
,
dst
=
dgl
.
sampling
.
global_uniform_negative_sampling
(
g
,
20
,
False
,
True
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment