Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
2d489617
Unverified
Commit
2d489617
authored
Sep 23, 2019
by
Da Zheng
Committed by
GitHub
Sep 23, 2019
Browse files
[Feature] find the existence of negative edges. (#875)
* find the existence of negative edges. * add comment. * fix test.
parent
02fe316d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
148 additions
and
25 deletions
+148
-25
include/dgl/graph_interface.h
include/dgl/graph_interface.h
+6
-0
python/dgl/contrib/sampling/sampler.py
python/dgl/contrib/sampling/sampler.py
+15
-2
src/graph/sampler.cc
src/graph/sampler.cc
+92
-11
tests/compute/test_sampler.py
tests/compute/test_sampler.py
+35
-12
No files found.
include/dgl/graph_interface.h
View file @
2d489617
...
@@ -372,6 +372,12 @@ struct Subgraph : public runtime::Object {
...
@@ -372,6 +372,12 @@ struct Subgraph : public runtime::Object {
DGL_DECLARE_OBJECT_TYPE_INFO
(
Subgraph
,
runtime
::
Object
);
DGL_DECLARE_OBJECT_TYPE_INFO
(
Subgraph
,
runtime
::
Object
);
};
};
/*! \brief Subgraph data structure for negative subgraph */
struct
NegSubgraph
:
public
Subgraph
{
/*! \brief The existence of the negative edges in the parent graph. */
IdArray
exist
;
};
// Define SubgraphRef
// Define SubgraphRef
DGL_DEFINE_OBJECT_REF
(
SubgraphRef
,
Subgraph
);
DGL_DEFINE_OBJECT_REF
(
SubgraphRef
,
Subgraph
);
...
...
python/dgl/contrib/sampling/sampler.py
View file @
2d489617
...
@@ -7,6 +7,7 @@ from numbers import Integral
...
@@ -7,6 +7,7 @@ from numbers import Integral
import
traceback
import
traceback
from
..._ffi.function
import
_init_api
from
..._ffi.function
import
_init_api
from
..._ffi.ndarray
import
empty
from
...
import
utils
from
...
import
utils
from
...nodeflow
import
NodeFlow
from
...nodeflow
import
NodeFlow
from
...
import
backend
as
F
from
...
import
backend
as
F
...
@@ -496,11 +497,20 @@ class EdgeSampler(object):
...
@@ -496,11 +497,20 @@ class EdgeSampler(object):
prefetch
=
False
,
prefetch
=
False
,
negative_mode
=
""
,
negative_mode
=
""
,
neg_sample_size
=
0
,
neg_sample_size
=
0
,
exclude_positive
=
False
):
exclude_positive
=
False
,
relations
=
None
):
self
.
_g
=
g
self
.
_g
=
g
if
self
.
immutable_only
and
not
g
.
_graph
.
is_readonly
():
if
self
.
immutable_only
and
not
g
.
_graph
.
is_readonly
():
raise
NotImplementedError
(
"This loader only support read-only graphs."
)
raise
NotImplementedError
(
"This loader only support read-only graphs."
)
if
relations
is
None
:
relations
=
empty
((
0
,),
'int64'
)
else
:
relations
=
utils
.
toindex
(
relations
)
relations
=
relations
.
todgltensor
()
assert
g
.
number_of_edges
()
==
len
(
relations
)
self
.
_relations
=
relations
self
.
_batch_size
=
int
(
batch_size
)
self
.
_batch_size
=
int
(
batch_size
)
if
seed_edges
is
None
:
if
seed_edges
is
None
:
...
@@ -544,7 +554,8 @@ class EdgeSampler(object):
...
@@ -544,7 +554,8 @@ class EdgeSampler(object):
self
.
_num_workers
,
# num batches
self
.
_num_workers
,
# num batches
self
.
_negative_mode
,
self
.
_negative_mode
,
self
.
_neg_sample_size
,
self
.
_neg_sample_size
,
self
.
_exclude_positive
)
self
.
_exclude_positive
,
self
.
_relations
)
if
len
(
subgs
)
==
0
:
if
len
(
subgs
)
==
0
:
return
[]
return
[]
...
@@ -559,6 +570,8 @@ class EdgeSampler(object):
...
@@ -559,6 +570,8 @@ class EdgeSampler(object):
for
i
in
range
(
num_pos
):
for
i
in
range
(
num_pos
):
pos_subg
=
subgraph
.
DGLSubGraph
(
self
.
g
,
subgs
[
i
])
pos_subg
=
subgraph
.
DGLSubGraph
(
self
.
g
,
subgs
[
i
])
neg_subg
=
subgraph
.
DGLSubGraph
(
self
.
g
,
subgs
[
i
+
num_pos
])
neg_subg
=
subgraph
.
DGLSubGraph
(
self
.
g
,
subgs
[
i
+
num_pos
])
exist
=
_CAPI_GetNegEdgeExistence
(
subgs
[
i
+
num_pos
]);
neg_subg
.
edata
[
'exist'
]
=
utils
.
toindex
(
exist
).
tousertensor
()
rets
.
append
((
pos_subg
,
neg_subg
))
rets
.
append
((
pos_subg
,
neg_subg
))
return
rets
return
rets
...
...
src/graph/sampler.cc
View file @
2d489617
...
@@ -907,9 +907,63 @@ inline bool is_neg_head_mode(const std::string &mode) {
...
@@ -907,9 +907,63 @@ inline bool is_neg_head_mode(const std::string &mode) {
return
mode
==
"head"
;
return
mode
==
"head"
;
}
}
Subgraph
NegEdgeSubgraph
(
GraphPtr
gptr
,
const
Subgraph
&
pos_subg
,
IdArray
GetGlobalVid
(
IdArray
induced_nid
,
IdArray
subg_nid
)
{
const
std
::
string
&
neg_mode
,
IdArray
gnid
=
IdArray
::
Empty
({
subg_nid
->
shape
[
0
]},
subg_nid
->
dtype
,
subg_nid
->
ctx
);
int
neg_sample_size
,
bool
exclude_positive
)
{
const
dgl_id_t
*
induced_nid_data
=
static_cast
<
dgl_id_t
*>
(
induced_nid
->
data
);
const
dgl_id_t
*
subg_nid_data
=
static_cast
<
dgl_id_t
*>
(
subg_nid
->
data
);
dgl_id_t
*
gnid_data
=
static_cast
<
dgl_id_t
*>
(
gnid
->
data
);
for
(
int64_t
i
=
0
;
i
<
subg_nid
->
shape
[
0
];
i
++
)
{
gnid_data
[
i
]
=
induced_nid_data
[
subg_nid_data
[
i
]];
}
return
gnid
;
}
IdArray
CheckExistence
(
GraphPtr
gptr
,
IdArray
neg_src
,
IdArray
neg_dst
,
IdArray
induced_nid
)
{
return
gptr
->
HasEdgesBetween
(
GetGlobalVid
(
induced_nid
,
neg_src
),
GetGlobalVid
(
induced_nid
,
neg_dst
));
}
IdArray
CheckExistence
(
GraphPtr
gptr
,
IdArray
relations
,
IdArray
neg_src
,
IdArray
neg_dst
,
IdArray
induced_nid
,
IdArray
neg_eid
)
{
neg_src
=
GetGlobalVid
(
induced_nid
,
neg_src
);
neg_dst
=
GetGlobalVid
(
induced_nid
,
neg_dst
);
BoolArray
exist
=
gptr
->
HasEdgesBetween
(
neg_src
,
neg_dst
);
dgl_id_t
*
neg_dst_data
=
static_cast
<
dgl_id_t
*>
(
neg_dst
->
data
);
dgl_id_t
*
neg_src_data
=
static_cast
<
dgl_id_t
*>
(
neg_src
->
data
);
dgl_id_t
*
neg_eid_data
=
static_cast
<
dgl_id_t
*>
(
neg_eid
->
data
);
dgl_id_t
*
relation_data
=
static_cast
<
dgl_id_t
*>
(
relations
->
data
);
// TODO(zhengda) is this right?
dgl_id_t
*
exist_data
=
static_cast
<
dgl_id_t
*>
(
exist
->
data
);
int64_t
num_neg_edges
=
neg_src
->
shape
[
0
];
for
(
int64_t
i
=
0
;
i
<
num_neg_edges
;
i
++
)
{
// If the edge doesn't exist, we don't need to do anything.
if
(
!
exist_data
[
i
])
continue
;
// If the edge exists, we need to double check if the relations match.
// If they match, this negative edge isn't really a negative edge.
dgl_id_t
eid1
=
neg_eid_data
[
i
];
dgl_id_t
orig_neg_rel1
=
relation_data
[
eid1
];
IdArray
eids
=
gptr
->
EdgeId
(
neg_src_data
[
i
],
neg_dst_data
[
i
]);
dgl_id_t
*
eid_data
=
static_cast
<
dgl_id_t
*>
(
eids
->
data
);
int64_t
num_edges_between
=
eids
->
shape
[
0
];
bool
same_rel
=
false
;
for
(
int64_t
j
=
0
;
j
<
num_edges_between
;
j
++
)
{
dgl_id_t
neg_rel1
=
relation_data
[
eid_data
[
j
]];
if
(
neg_rel1
==
orig_neg_rel1
)
{
same_rel
=
true
;
break
;
}
}
exist_data
[
i
]
=
same_rel
;
}
return
exist
;
}
NegSubgraph
NegEdgeSubgraph
(
GraphPtr
gptr
,
IdArray
relations
,
const
Subgraph
&
pos_subg
,
const
std
::
string
&
neg_mode
,
int
neg_sample_size
,
bool
exclude_positive
)
{
int64_t
num_tot_nodes
=
gptr
->
NumVertices
();
int64_t
num_tot_nodes
=
gptr
->
NumVertices
();
bool
is_multigraph
=
gptr
->
IsMultigraph
();
bool
is_multigraph
=
gptr
->
IsMultigraph
();
std
::
vector
<
IdArray
>
adj
=
pos_subg
.
graph
->
GetAdj
(
false
,
"coo"
);
std
::
vector
<
IdArray
>
adj
=
pos_subg
.
graph
->
GetAdj
(
false
,
"coo"
);
...
@@ -991,20 +1045,28 @@ Subgraph NegEdgeSubgraph(GraphPtr gptr, const Subgraph &pos_subg,
...
@@ -991,20 +1045,28 @@ Subgraph NegEdgeSubgraph(GraphPtr gptr, const Subgraph &pos_subg,
induced_neg_vid_data
[
it
->
second
]
=
it
->
first
;
induced_neg_vid_data
[
it
->
second
]
=
it
->
first
;
}
}
Subgraph
neg_subg
;
Neg
Subgraph
neg_subg
;
// We sample negative vertices without replacement.
// We sample negative vertices without replacement.
// There shouldn't be duplicated edges.
// There shouldn't be duplicated edges.
COOPtr
neg_coo
(
new
COO
(
num_neg_nodes
,
neg_src
,
neg_dst
,
is_multigraph
));
COOPtr
neg_coo
(
new
COO
(
num_neg_nodes
,
neg_src
,
neg_dst
,
is_multigraph
));
neg_subg
.
graph
=
GraphPtr
(
new
ImmutableGraph
(
neg_coo
));
neg_subg
.
graph
=
GraphPtr
(
new
ImmutableGraph
(
neg_coo
));
neg_subg
.
induced_vertices
=
induced_neg_vid
;
neg_subg
.
induced_vertices
=
induced_neg_vid
;
neg_subg
.
induced_edges
=
induced_neg_eid
;
neg_subg
.
induced_edges
=
induced_neg_eid
;
// TODO(zhengda) we should provide an array of 1s if exclude_positive
if
(
relations
->
shape
[
0
]
==
0
)
{
neg_subg
.
exist
=
CheckExistence
(
gptr
,
neg_src
,
neg_dst
,
induced_neg_vid
);
}
else
{
neg_subg
.
exist
=
CheckExistence
(
gptr
,
relations
,
neg_src
,
neg_dst
,
induced_neg_vid
,
induced_neg_eid
);
}
return
neg_subg
;
return
neg_subg
;
}
}
Subgraph
PBGNegEdgeSubgraph
(
int64_t
num_tot_node
s
,
const
Subgraph
&
pos_subg
,
Neg
Subgraph
PBGNegEdgeSubgraph
(
GraphPtr
gptr
,
IdArray
relation
s
,
const
Subgraph
&
pos_subg
,
const
std
::
string
&
neg_mode
,
const
std
::
string
&
neg_mode
,
int
neg_sample_size
,
bool
is_multigraph
,
int
neg_sample_size
,
bool
is_multigraph
,
bool
exclude_positive
)
{
bool
exclude_positive
)
{
int64_t
num_tot_nodes
=
gptr
->
NumVertices
();
std
::
vector
<
IdArray
>
adj
=
pos_subg
.
graph
->
GetAdj
(
false
,
"coo"
);
std
::
vector
<
IdArray
>
adj
=
pos_subg
.
graph
->
GetAdj
(
false
,
"coo"
);
IdArray
coo
=
adj
[
0
];
IdArray
coo
=
adj
[
0
];
int64_t
num_pos_edges
=
coo
->
shape
[
0
]
/
2
;
int64_t
num_pos_edges
=
coo
->
shape
[
0
]
/
2
;
...
@@ -1107,13 +1169,19 @@ Subgraph PBGNegEdgeSubgraph(int64_t num_tot_nodes, const Subgraph &pos_subg,
...
@@ -1107,13 +1169,19 @@ Subgraph PBGNegEdgeSubgraph(int64_t num_tot_nodes, const Subgraph &pos_subg,
induced_neg_vid_data
[
it
->
second
]
=
it
->
first
;
induced_neg_vid_data
[
it
->
second
]
=
it
->
first
;
}
}
Subgraph
neg_subg
;
Neg
Subgraph
neg_subg
;
// We sample negative vertices without replacement.
// We sample negative vertices without replacement.
// There shouldn't be duplicated edges.
// There shouldn't be duplicated edges.
COOPtr
neg_coo
(
new
COO
(
num_neg_nodes
,
neg_src
,
neg_dst
,
is_multigraph
));
COOPtr
neg_coo
(
new
COO
(
num_neg_nodes
,
neg_src
,
neg_dst
,
is_multigraph
));
neg_subg
.
graph
=
GraphPtr
(
new
ImmutableGraph
(
neg_coo
));
neg_subg
.
graph
=
GraphPtr
(
new
ImmutableGraph
(
neg_coo
));
neg_subg
.
induced_vertices
=
induced_neg_vid
;
neg_subg
.
induced_vertices
=
induced_neg_vid
;
neg_subg
.
induced_edges
=
induced_neg_eid
;
neg_subg
.
induced_edges
=
induced_neg_eid
;
if
(
relations
->
shape
[
0
]
==
0
)
{
neg_subg
.
exist
=
CheckExistence
(
gptr
,
neg_src
,
neg_dst
,
induced_neg_vid
);
}
else
{
neg_subg
.
exist
=
CheckExistence
(
gptr
,
relations
,
neg_src
,
neg_dst
,
induced_neg_vid
,
induced_neg_eid
);
}
return
neg_subg
;
return
neg_subg
;
}
}
...
@@ -1121,6 +1189,10 @@ inline SubgraphRef ConvertRef(const Subgraph &subg) {
...
@@ -1121,6 +1189,10 @@ inline SubgraphRef ConvertRef(const Subgraph &subg) {
return
SubgraphRef
(
std
::
shared_ptr
<
Subgraph
>
(
new
Subgraph
(
subg
)));
return
SubgraphRef
(
std
::
shared_ptr
<
Subgraph
>
(
new
Subgraph
(
subg
)));
}
}
inline
SubgraphRef
ConvertRef
(
const
NegSubgraph
&
subg
)
{
return
SubgraphRef
(
std
::
shared_ptr
<
Subgraph
>
(
new
NegSubgraph
(
subg
)));
}
}
// namespace
}
// namespace
DGL_REGISTER_GLOBAL
(
"sampling._CAPI_UniformEdgeSampling"
)
DGL_REGISTER_GLOBAL
(
"sampling._CAPI_UniformEdgeSampling"
)
...
@@ -1134,6 +1206,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformEdgeSampling")
...
@@ -1134,6 +1206,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformEdgeSampling")
const
std
::
string
neg_mode
=
args
[
5
];
const
std
::
string
neg_mode
=
args
[
5
];
const
int
neg_sample_size
=
args
[
6
];
const
int
neg_sample_size
=
args
[
6
];
const
bool
exclude_positive
=
args
[
7
];
const
bool
exclude_positive
=
args
[
7
];
IdArray
relations
=
args
[
8
];
// process args
// process args
auto
gptr
=
std
::
dynamic_pointer_cast
<
ImmutableGraph
>
(
g
.
sptr
());
auto
gptr
=
std
::
dynamic_pointer_cast
<
ImmutableGraph
>
(
g
.
sptr
());
CHECK
(
gptr
)
<<
"sampling isn't implemented in mutable graph"
;
CHECK
(
gptr
)
<<
"sampling isn't implemented in mutable graph"
;
...
@@ -1165,13 +1238,13 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformEdgeSampling")
...
@@ -1165,13 +1238,13 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformEdgeSampling")
// For PBG negative sampling, we accept "PBG-head" for corrupting head
// For PBG negative sampling, we accept "PBG-head" for corrupting head
// nodes and "PBG-tail" for corrupting tail nodes.
// nodes and "PBG-tail" for corrupting tail nodes.
if
(
neg_mode
.
substr
(
0
,
3
)
==
"PBG"
)
{
if
(
neg_mode
.
substr
(
0
,
3
)
==
"PBG"
)
{
Subgraph
neg_subg
=
PBGNegEdgeSubgraph
(
gptr
->
NumVertices
()
,
subg
,
Neg
Subgraph
neg_subg
=
PBGNegEdgeSubgraph
(
gptr
,
relations
,
subg
,
neg_mode
.
substr
(
4
),
neg_sample_size
,
neg_mode
.
substr
(
4
),
neg_sample_size
,
gptr
->
IsMultigraph
(),
exclude_positive
);
gptr
->
IsMultigraph
(),
exclude_positive
);
negative_subgs
[
i
]
=
ConvertRef
(
neg_subg
);
negative_subgs
[
i
]
=
ConvertRef
(
neg_subg
);
}
else
if
(
neg_mode
.
size
()
>
0
)
{
}
else
if
(
neg_mode
.
size
()
>
0
)
{
Subgraph
neg_subg
=
NegEdgeSubgraph
(
gptr
,
subg
,
neg_mode
,
neg_sample_size
,
Neg
Subgraph
neg_subg
=
NegEdgeSubgraph
(
gptr
,
relations
,
subg
,
neg_mode
,
neg_sample_size
,
exclude_positive
);
exclude_positive
);
negative_subgs
[
i
]
=
ConvertRef
(
neg_subg
);
negative_subgs
[
i
]
=
ConvertRef
(
neg_subg
);
}
}
}
}
...
@@ -1181,4 +1254,12 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformEdgeSampling")
...
@@ -1181,4 +1254,12 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformEdgeSampling")
*
rv
=
List
<
SubgraphRef
>
(
positive_subgs
);
*
rv
=
List
<
SubgraphRef
>
(
positive_subgs
);
});
});
DGL_REGISTER_GLOBAL
(
"sampling._CAPI_GetNegEdgeExistence"
)
.
set_body
([]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
SubgraphRef
g
=
args
[
0
];
auto
gptr
=
std
::
dynamic_pointer_cast
<
NegSubgraph
>
(
g
.
sptr
());
*
rv
=
gptr
->
exist
;
});
}
// namespace dgl
}
// namespace dgl
tests/compute/test_sampler.py
View file @
2d489617
...
@@ -222,6 +222,8 @@ def test_setseed():
...
@@ -222,6 +222,8 @@ def test_setseed():
def
check_negative_sampler
(
mode
,
exclude_positive
):
def
check_negative_sampler
(
mode
,
exclude_positive
):
g
=
generate_rand_graph
(
100
)
g
=
generate_rand_graph
(
100
)
etype
=
np
.
random
.
randint
(
0
,
10
,
size
=
g
.
number_of_edges
(),
dtype
=
np
.
int64
)
g
.
edata
[
'etype'
]
=
F
.
tensor
(
etype
)
pos_gsrc
,
pos_gdst
,
pos_geid
=
g
.
all_edges
(
form
=
'all'
,
order
=
'eid'
)
pos_gsrc
,
pos_gdst
,
pos_geid
=
g
.
all_edges
(
form
=
'all'
,
order
=
'eid'
)
pos_map
=
{}
pos_map
=
{}
...
@@ -232,25 +234,20 @@ def check_negative_sampler(mode, exclude_positive):
...
@@ -232,25 +234,20 @@ def check_negative_sampler(mode, exclude_positive):
EdgeSampler
=
getattr
(
dgl
.
contrib
.
sampling
,
'EdgeSampler'
)
EdgeSampler
=
getattr
(
dgl
.
contrib
.
sampling
,
'EdgeSampler'
)
neg_size
=
10
neg_size
=
10
# Test the homogeneous graph.
for
pos_edges
,
neg_edges
in
EdgeSampler
(
g
,
50
,
for
pos_edges
,
neg_edges
in
EdgeSampler
(
g
,
50
,
negative_mode
=
mode
,
negative_mode
=
mode
,
neg_sample_size
=
neg_size
,
neg_sample_size
=
neg_size
,
exclude_positive
=
exclude_positive
):
exclude_positive
=
exclude_positive
):
pos_nid
=
pos_edges
.
parent_nid
pos_eid
=
pos_edges
.
parent_eid
pos_lsrc
,
pos_ldst
,
pos_leid
=
pos_edges
.
all_edges
(
form
=
'all'
,
order
=
'eid'
)
pos_lsrc
,
pos_ldst
,
pos_leid
=
pos_edges
.
all_edges
(
form
=
'all'
,
order
=
'eid'
)
pos_src
=
pos_nid
[
pos_lsrc
]
assert_array_equal
(
F
.
asnumpy
(
pos_edges
.
parent_eid
[
pos_leid
]),
pos_dst
=
pos_nid
[
pos_ldst
]
F
.
asnumpy
(
g
.
edge_ids
(
pos_edges
.
parent_nid
[
pos_lsrc
],
pos_eid
=
pos_eid
[
pos_leid
]
pos_edges
.
parent_nid
[
pos_ldst
])))
assert_array_equal
(
F
.
asnumpy
(
pos_eid
),
F
.
asnumpy
(
g
.
edge_ids
(
pos_src
,
pos_dst
)))
neg_lsrc
,
neg_ldst
,
neg_leid
=
neg_edges
.
all_edges
(
form
=
'all'
,
order
=
'eid'
)
neg_lsrc
,
neg_ldst
,
neg_leid
=
neg_edges
.
all_edges
(
form
=
'all'
,
order
=
'eid'
)
neg_nid
=
neg_edges
.
parent_nid
neg_src
=
neg_edges
.
parent_nid
[
neg_lsrc
]
neg_eid
=
neg_edges
.
parent_eid
neg_dst
=
neg_edges
.
parent_nid
[
neg_ldst
]
neg_src
=
neg_nid
[
neg_lsrc
]
neg_eid
=
neg_edges
.
parent_eid
[
neg_leid
]
neg_dst
=
neg_nid
[
neg_ldst
]
neg_eid
=
neg_eid
[
neg_leid
]
for
i
in
range
(
len
(
neg_eid
)):
for
i
in
range
(
len
(
neg_eid
)):
neg_d
=
int
(
F
.
asnumpy
(
neg_dst
[
i
]))
neg_d
=
int
(
F
.
asnumpy
(
neg_dst
[
i
]))
neg_e
=
int
(
F
.
asnumpy
(
neg_eid
[
i
]))
neg_e
=
int
(
F
.
asnumpy
(
neg_eid
[
i
]))
...
@@ -258,6 +255,32 @@ def check_negative_sampler(mode, exclude_positive):
...
@@ -258,6 +255,32 @@ def check_negative_sampler(mode, exclude_positive):
if
exclude_positive
:
if
exclude_positive
:
assert
int
(
F
.
asnumpy
(
neg_src
[
i
]))
!=
pos_map
[(
neg_d
,
neg_e
)]
assert
int
(
F
.
asnumpy
(
neg_src
[
i
]))
!=
pos_map
[(
neg_d
,
neg_e
)]
exist
=
neg_edges
.
edata
[
'exist'
]
if
exclude_positive
:
assert
np
.
sum
(
F
.
asnumpy
(
exist
)
==
0
)
==
len
(
exist
)
else
:
assert
F
.
array_equal
(
g
.
has_edges_between
(
neg_src
,
neg_dst
),
exist
)
# Test the knowledge graph.
for
_
,
neg_edges
in
EdgeSampler
(
g
,
50
,
negative_mode
=
mode
,
neg_sample_size
=
neg_size
,
exclude_positive
=
exclude_positive
,
relations
=
g
.
edata
[
'etype'
]):
neg_lsrc
,
neg_ldst
,
neg_leid
=
neg_edges
.
all_edges
(
form
=
'all'
,
order
=
'eid'
)
neg_src
=
neg_edges
.
parent_nid
[
neg_lsrc
]
neg_dst
=
neg_edges
.
parent_nid
[
neg_ldst
]
neg_eid
=
neg_edges
.
parent_eid
[
neg_leid
]
exists
=
neg_edges
.
edata
[
'exist'
]
neg_edges
.
edata
[
'etype'
]
=
g
.
edata
[
'etype'
][
neg_eid
]
for
i
in
range
(
len
(
neg_eid
)):
u
,
v
=
F
.
asnumpy
(
neg_src
[
i
]),
F
.
asnumpy
(
neg_dst
[
i
])
if
g
.
has_edge_between
(
u
,
v
):
eid
=
g
.
edge_id
(
u
,
v
)
etype
=
g
.
edata
[
'etype'
][
eid
]
exist
=
neg_edges
.
edata
[
'etype'
][
i
]
==
etype
assert
F
.
asnumpy
(
exists
[
i
])
==
F
.
asnumpy
(
exist
)
def
test_negative_sampler
():
def
test_negative_sampler
():
check_negative_sampler
(
'head'
,
True
)
check_negative_sampler
(
'head'
,
True
)
check_negative_sampler
(
'PBG-head'
,
False
)
check_negative_sampler
(
'PBG-head'
,
False
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment