Unverified Commit 2d489617 authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[Feature] find the existence of negative edges. (#875)

* find the existence of negative edges.

* add comment.

* fix test.
parent 02fe316d
...@@ -372,6 +372,12 @@ struct Subgraph : public runtime::Object { ...@@ -372,6 +372,12 @@ struct Subgraph : public runtime::Object {
DGL_DECLARE_OBJECT_TYPE_INFO(Subgraph, runtime::Object); DGL_DECLARE_OBJECT_TYPE_INFO(Subgraph, runtime::Object);
}; };
/*! \brief Subgraph data structure for negative subgraph */
struct NegSubgraph: public Subgraph {
/*! \brief The existence of the negative edges in the parent graph. */
IdArray exist;
};
// Define SubgraphRef // Define SubgraphRef
DGL_DEFINE_OBJECT_REF(SubgraphRef, Subgraph); DGL_DEFINE_OBJECT_REF(SubgraphRef, Subgraph);
......
...@@ -7,6 +7,7 @@ from numbers import Integral ...@@ -7,6 +7,7 @@ from numbers import Integral
import traceback import traceback
from ..._ffi.function import _init_api from ..._ffi.function import _init_api
from ..._ffi.ndarray import empty
from ... import utils from ... import utils
from ...nodeflow import NodeFlow from ...nodeflow import NodeFlow
from ... import backend as F from ... import backend as F
...@@ -496,11 +497,20 @@ class EdgeSampler(object): ...@@ -496,11 +497,20 @@ class EdgeSampler(object):
prefetch=False, prefetch=False,
negative_mode="", negative_mode="",
neg_sample_size=0, neg_sample_size=0,
exclude_positive=False): exclude_positive=False,
relations=None):
self._g = g self._g = g
if self.immutable_only and not g._graph.is_readonly(): if self.immutable_only and not g._graph.is_readonly():
raise NotImplementedError("This loader only support read-only graphs.") raise NotImplementedError("This loader only support read-only graphs.")
if relations is None:
relations = empty((0,), 'int64')
else:
relations = utils.toindex(relations)
relations = relations.todgltensor()
assert g.number_of_edges() == len(relations)
self._relations = relations
self._batch_size = int(batch_size) self._batch_size = int(batch_size)
if seed_edges is None: if seed_edges is None:
...@@ -544,7 +554,8 @@ class EdgeSampler(object): ...@@ -544,7 +554,8 @@ class EdgeSampler(object):
self._num_workers, # num batches self._num_workers, # num batches
self._negative_mode, self._negative_mode,
self._neg_sample_size, self._neg_sample_size,
self._exclude_positive) self._exclude_positive,
self._relations)
if len(subgs) == 0: if len(subgs) == 0:
return [] return []
...@@ -559,6 +570,8 @@ class EdgeSampler(object): ...@@ -559,6 +570,8 @@ class EdgeSampler(object):
for i in range(num_pos): for i in range(num_pos):
pos_subg = subgraph.DGLSubGraph(self.g, subgs[i]) pos_subg = subgraph.DGLSubGraph(self.g, subgs[i])
neg_subg = subgraph.DGLSubGraph(self.g, subgs[i + num_pos]) neg_subg = subgraph.DGLSubGraph(self.g, subgs[i + num_pos])
exist = _CAPI_GetNegEdgeExistence(subgs[i + num_pos]);
neg_subg.edata['exist'] = utils.toindex(exist).tousertensor()
rets.append((pos_subg, neg_subg)) rets.append((pos_subg, neg_subg))
return rets return rets
......
...@@ -907,7 +907,61 @@ inline bool is_neg_head_mode(const std::string &mode) { ...@@ -907,7 +907,61 @@ inline bool is_neg_head_mode(const std::string &mode) {
return mode == "head"; return mode == "head";
} }
Subgraph NegEdgeSubgraph(GraphPtr gptr, const Subgraph &pos_subg, IdArray GetGlobalVid(IdArray induced_nid, IdArray subg_nid) {
IdArray gnid = IdArray::Empty({subg_nid->shape[0]}, subg_nid->dtype, subg_nid->ctx);
const dgl_id_t *induced_nid_data = static_cast<dgl_id_t *>(induced_nid->data);
const dgl_id_t *subg_nid_data = static_cast<dgl_id_t *>(subg_nid->data);
dgl_id_t *gnid_data = static_cast<dgl_id_t *>(gnid->data);
for (int64_t i = 0; i < subg_nid->shape[0]; i++) {
gnid_data[i] = induced_nid_data[subg_nid_data[i]];
}
return gnid;
}
IdArray CheckExistence(GraphPtr gptr, IdArray neg_src, IdArray neg_dst,
IdArray induced_nid) {
return gptr->HasEdgesBetween(GetGlobalVid(induced_nid, neg_src),
GetGlobalVid(induced_nid, neg_dst));
}
IdArray CheckExistence(GraphPtr gptr, IdArray relations,
IdArray neg_src, IdArray neg_dst,
IdArray induced_nid, IdArray neg_eid) {
neg_src = GetGlobalVid(induced_nid, neg_src);
neg_dst = GetGlobalVid(induced_nid, neg_dst);
BoolArray exist = gptr->HasEdgesBetween(neg_src, neg_dst);
dgl_id_t *neg_dst_data = static_cast<dgl_id_t *>(neg_dst->data);
dgl_id_t *neg_src_data = static_cast<dgl_id_t *>(neg_src->data);
dgl_id_t *neg_eid_data = static_cast<dgl_id_t *>(neg_eid->data);
dgl_id_t *relation_data = static_cast<dgl_id_t *>(relations->data);
// TODO(zhengda) is this right?
dgl_id_t *exist_data = static_cast<dgl_id_t *>(exist->data);
int64_t num_neg_edges = neg_src->shape[0];
for (int64_t i = 0; i < num_neg_edges; i++) {
// If the edge doesn't exist, we don't need to do anything.
if (!exist_data[i])
continue;
// If the edge exists, we need to double check if the relations match.
// If they match, this negative edge isn't really a negative edge.
dgl_id_t eid1 = neg_eid_data[i];
dgl_id_t orig_neg_rel1 = relation_data[eid1];
IdArray eids = gptr->EdgeId(neg_src_data[i], neg_dst_data[i]);
dgl_id_t *eid_data = static_cast<dgl_id_t *>(eids->data);
int64_t num_edges_between = eids->shape[0];
bool same_rel = false;
for (int64_t j = 0; j < num_edges_between; j++) {
dgl_id_t neg_rel1 = relation_data[eid_data[j]];
if (neg_rel1 == orig_neg_rel1) {
same_rel = true;
break;
}
}
exist_data[i] = same_rel;
}
return exist;
}
NegSubgraph NegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph &pos_subg,
const std::string &neg_mode, const std::string &neg_mode,
int neg_sample_size, bool exclude_positive) { int neg_sample_size, bool exclude_positive) {
int64_t num_tot_nodes = gptr->NumVertices(); int64_t num_tot_nodes = gptr->NumVertices();
...@@ -991,20 +1045,28 @@ Subgraph NegEdgeSubgraph(GraphPtr gptr, const Subgraph &pos_subg, ...@@ -991,20 +1045,28 @@ Subgraph NegEdgeSubgraph(GraphPtr gptr, const Subgraph &pos_subg,
induced_neg_vid_data[it->second] = it->first; induced_neg_vid_data[it->second] = it->first;
} }
Subgraph neg_subg; NegSubgraph neg_subg;
// We sample negative vertices without replacement. // We sample negative vertices without replacement.
// There shouldn't be duplicated edges. // There shouldn't be duplicated edges.
COOPtr neg_coo(new COO(num_neg_nodes, neg_src, neg_dst, is_multigraph)); COOPtr neg_coo(new COO(num_neg_nodes, neg_src, neg_dst, is_multigraph));
neg_subg.graph = GraphPtr(new ImmutableGraph(neg_coo)); neg_subg.graph = GraphPtr(new ImmutableGraph(neg_coo));
neg_subg.induced_vertices = induced_neg_vid; neg_subg.induced_vertices = induced_neg_vid;
neg_subg.induced_edges = induced_neg_eid; neg_subg.induced_edges = induced_neg_eid;
// TODO(zhengda) we should provide an array of 1s if exclude_positive
if (relations->shape[0] == 0) {
neg_subg.exist = CheckExistence(gptr, neg_src, neg_dst, induced_neg_vid);
} else {
neg_subg.exist = CheckExistence(gptr, relations, neg_src, neg_dst,
induced_neg_vid, induced_neg_eid);
}
return neg_subg; return neg_subg;
} }
Subgraph PBGNegEdgeSubgraph(int64_t num_tot_nodes, const Subgraph &pos_subg, NegSubgraph PBGNegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph &pos_subg,
const std::string &neg_mode, const std::string &neg_mode,
int neg_sample_size, bool is_multigraph, int neg_sample_size, bool is_multigraph,
bool exclude_positive) { bool exclude_positive) {
int64_t num_tot_nodes = gptr->NumVertices();
std::vector<IdArray> adj = pos_subg.graph->GetAdj(false, "coo"); std::vector<IdArray> adj = pos_subg.graph->GetAdj(false, "coo");
IdArray coo = adj[0]; IdArray coo = adj[0];
int64_t num_pos_edges = coo->shape[0] / 2; int64_t num_pos_edges = coo->shape[0] / 2;
...@@ -1107,13 +1169,19 @@ Subgraph PBGNegEdgeSubgraph(int64_t num_tot_nodes, const Subgraph &pos_subg, ...@@ -1107,13 +1169,19 @@ Subgraph PBGNegEdgeSubgraph(int64_t num_tot_nodes, const Subgraph &pos_subg,
induced_neg_vid_data[it->second] = it->first; induced_neg_vid_data[it->second] = it->first;
} }
Subgraph neg_subg; NegSubgraph neg_subg;
// We sample negative vertices without replacement. // We sample negative vertices without replacement.
// There shouldn't be duplicated edges. // There shouldn't be duplicated edges.
COOPtr neg_coo(new COO(num_neg_nodes, neg_src, neg_dst, is_multigraph)); COOPtr neg_coo(new COO(num_neg_nodes, neg_src, neg_dst, is_multigraph));
neg_subg.graph = GraphPtr(new ImmutableGraph(neg_coo)); neg_subg.graph = GraphPtr(new ImmutableGraph(neg_coo));
neg_subg.induced_vertices = induced_neg_vid; neg_subg.induced_vertices = induced_neg_vid;
neg_subg.induced_edges = induced_neg_eid; neg_subg.induced_edges = induced_neg_eid;
if (relations->shape[0] == 0) {
neg_subg.exist = CheckExistence(gptr, neg_src, neg_dst, induced_neg_vid);
} else {
neg_subg.exist = CheckExistence(gptr, relations, neg_src, neg_dst,
induced_neg_vid, induced_neg_eid);
}
return neg_subg; return neg_subg;
} }
...@@ -1121,6 +1189,10 @@ inline SubgraphRef ConvertRef(const Subgraph &subg) { ...@@ -1121,6 +1189,10 @@ inline SubgraphRef ConvertRef(const Subgraph &subg) {
return SubgraphRef(std::shared_ptr<Subgraph>(new Subgraph(subg))); return SubgraphRef(std::shared_ptr<Subgraph>(new Subgraph(subg)));
} }
inline SubgraphRef ConvertRef(const NegSubgraph &subg) {
return SubgraphRef(std::shared_ptr<Subgraph>(new NegSubgraph(subg)));
}
} // namespace } // namespace
DGL_REGISTER_GLOBAL("sampling._CAPI_UniformEdgeSampling") DGL_REGISTER_GLOBAL("sampling._CAPI_UniformEdgeSampling")
...@@ -1134,6 +1206,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformEdgeSampling") ...@@ -1134,6 +1206,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformEdgeSampling")
const std::string neg_mode = args[5]; const std::string neg_mode = args[5];
const int neg_sample_size = args[6]; const int neg_sample_size = args[6];
const bool exclude_positive = args[7]; const bool exclude_positive = args[7];
IdArray relations = args[8];
// process args // process args
auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()); auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(gptr) << "sampling isn't implemented in mutable graph"; CHECK(gptr) << "sampling isn't implemented in mutable graph";
...@@ -1165,12 +1238,12 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformEdgeSampling") ...@@ -1165,12 +1238,12 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformEdgeSampling")
// For PBG negative sampling, we accept "PBG-head" for corrupting head // For PBG negative sampling, we accept "PBG-head" for corrupting head
// nodes and "PBG-tail" for corrupting tail nodes. // nodes and "PBG-tail" for corrupting tail nodes.
if (neg_mode.substr(0, 3) == "PBG") { if (neg_mode.substr(0, 3) == "PBG") {
Subgraph neg_subg = PBGNegEdgeSubgraph(gptr->NumVertices(), subg, NegSubgraph neg_subg = PBGNegEdgeSubgraph(gptr, relations, subg,
neg_mode.substr(4), neg_sample_size, neg_mode.substr(4), neg_sample_size,
gptr->IsMultigraph(), exclude_positive); gptr->IsMultigraph(), exclude_positive);
negative_subgs[i] = ConvertRef(neg_subg); negative_subgs[i] = ConvertRef(neg_subg);
} else if (neg_mode.size() > 0) { } else if (neg_mode.size() > 0) {
Subgraph neg_subg = NegEdgeSubgraph(gptr, subg, neg_mode, neg_sample_size, NegSubgraph neg_subg = NegEdgeSubgraph(gptr, relations, subg, neg_mode, neg_sample_size,
exclude_positive); exclude_positive);
negative_subgs[i] = ConvertRef(neg_subg); negative_subgs[i] = ConvertRef(neg_subg);
} }
...@@ -1181,4 +1254,12 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformEdgeSampling") ...@@ -1181,4 +1254,12 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformEdgeSampling")
*rv = List<SubgraphRef>(positive_subgs); *rv = List<SubgraphRef>(positive_subgs);
}); });
DGL_REGISTER_GLOBAL("sampling._CAPI_GetNegEdgeExistence")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
SubgraphRef g = args[0];
auto gptr = std::dynamic_pointer_cast<NegSubgraph>(g.sptr());
*rv = gptr->exist;
});
} // namespace dgl } // namespace dgl
...@@ -222,6 +222,8 @@ def test_setseed(): ...@@ -222,6 +222,8 @@ def test_setseed():
def check_negative_sampler(mode, exclude_positive): def check_negative_sampler(mode, exclude_positive):
g = generate_rand_graph(100) g = generate_rand_graph(100)
etype = np.random.randint(0, 10, size=g.number_of_edges(), dtype=np.int64)
g.edata['etype'] = F.tensor(etype)
pos_gsrc, pos_gdst, pos_geid = g.all_edges(form='all', order='eid') pos_gsrc, pos_gdst, pos_geid = g.all_edges(form='all', order='eid')
pos_map = {} pos_map = {}
...@@ -232,25 +234,20 @@ def check_negative_sampler(mode, exclude_positive): ...@@ -232,25 +234,20 @@ def check_negative_sampler(mode, exclude_positive):
EdgeSampler = getattr(dgl.contrib.sampling, 'EdgeSampler') EdgeSampler = getattr(dgl.contrib.sampling, 'EdgeSampler')
neg_size = 10 neg_size = 10
# Test the homogeneous graph.
for pos_edges, neg_edges in EdgeSampler(g, 50, for pos_edges, neg_edges in EdgeSampler(g, 50,
negative_mode=mode, negative_mode=mode,
neg_sample_size=neg_size, neg_sample_size=neg_size,
exclude_positive=exclude_positive): exclude_positive=exclude_positive):
pos_nid = pos_edges.parent_nid
pos_eid = pos_edges.parent_eid
pos_lsrc, pos_ldst, pos_leid = pos_edges.all_edges(form='all', order='eid') pos_lsrc, pos_ldst, pos_leid = pos_edges.all_edges(form='all', order='eid')
pos_src = pos_nid[pos_lsrc] assert_array_equal(F.asnumpy(pos_edges.parent_eid[pos_leid]),
pos_dst = pos_nid[pos_ldst] F.asnumpy(g.edge_ids(pos_edges.parent_nid[pos_lsrc],
pos_eid = pos_eid[pos_leid] pos_edges.parent_nid[pos_ldst])))
assert_array_equal(F.asnumpy(pos_eid), F.asnumpy(g.edge_ids(pos_src, pos_dst)))
neg_lsrc, neg_ldst, neg_leid = neg_edges.all_edges(form='all', order='eid') neg_lsrc, neg_ldst, neg_leid = neg_edges.all_edges(form='all', order='eid')
neg_nid = neg_edges.parent_nid neg_src = neg_edges.parent_nid[neg_lsrc]
neg_eid = neg_edges.parent_eid neg_dst = neg_edges.parent_nid[neg_ldst]
neg_src = neg_nid[neg_lsrc] neg_eid = neg_edges.parent_eid[neg_leid]
neg_dst = neg_nid[neg_ldst]
neg_eid = neg_eid[neg_leid]
for i in range(len(neg_eid)): for i in range(len(neg_eid)):
neg_d = int(F.asnumpy(neg_dst[i])) neg_d = int(F.asnumpy(neg_dst[i]))
neg_e = int(F.asnumpy(neg_eid[i])) neg_e = int(F.asnumpy(neg_eid[i]))
...@@ -258,6 +255,32 @@ def check_negative_sampler(mode, exclude_positive): ...@@ -258,6 +255,32 @@ def check_negative_sampler(mode, exclude_positive):
if exclude_positive: if exclude_positive:
assert int(F.asnumpy(neg_src[i])) != pos_map[(neg_d, neg_e)] assert int(F.asnumpy(neg_src[i])) != pos_map[(neg_d, neg_e)]
exist = neg_edges.edata['exist']
if exclude_positive:
assert np.sum(F.asnumpy(exist) == 0) == len(exist)
else:
assert F.array_equal(g.has_edges_between(neg_src, neg_dst), exist)
# Test the knowledge graph.
for _, neg_edges in EdgeSampler(g, 50,
negative_mode=mode,
neg_sample_size=neg_size,
exclude_positive=exclude_positive,
relations=g.edata['etype']):
neg_lsrc, neg_ldst, neg_leid = neg_edges.all_edges(form='all', order='eid')
neg_src = neg_edges.parent_nid[neg_lsrc]
neg_dst = neg_edges.parent_nid[neg_ldst]
neg_eid = neg_edges.parent_eid[neg_leid]
exists = neg_edges.edata['exist']
neg_edges.edata['etype'] = g.edata['etype'][neg_eid]
for i in range(len(neg_eid)):
u, v = F.asnumpy(neg_src[i]), F.asnumpy(neg_dst[i])
if g.has_edge_between(u, v):
eid = g.edge_id(u, v)
etype = g.edata['etype'][eid]
exist = neg_edges.edata['etype'][i] == etype
assert F.asnumpy(exists[i]) == F.asnumpy(exist)
def test_negative_sampler(): def test_negative_sampler():
check_negative_sampler('head', True) check_negative_sampler('head', True)
check_negative_sampler('PBG-head', False) check_negative_sampler('PBG-head', False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment