"git@developer.sourcefind.cn:OpenDAS/fairseq.git" did not exist on "6c00b338c54b17bfd5343a4eabcb4d0df160764e"
Unverified Commit 1db697ec authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

Improve edge sampler (#881)

* optimizer for sampling all negative edges.

* allow to disable checking false negative.

* fix lint.

* fix docstring.

* fix for comment.

* add comments.
parent 2d489617
...@@ -454,9 +454,26 @@ class EdgeSampler(object): ...@@ -454,9 +454,26 @@ class EdgeSampler(object):
When negative edges are created, a batch of negative edges are also placed When negative edges are created, a batch of negative edges are also placed
in a subgraph. in a subgraph.
Currently, negative_mode only supports only 'head' and 'tail'. Currently, negative_mode only supports:
If negative_mode=='head', the negative edges are generated by corrupting 'head': the negative edges are generated by corrupting head nodes
head nodes; otherwise, the tail nodes are corrupted. with uniformly randomly sampled nodes,
'tail': the negative edges are generated by corrupting tail nodes
with uniformly randomly sampled nodes,
'PBG-head': the negative edges are generated by corrupting a set
of head nodes with the same set of nodes uniformly randomly sampled
from the graph. Please see Pytorch-BigGraph for more details.
'PBG-tail': the negative edges are generated by corrupting a set
of tail nodes with the same set of nodes similar to 'PBG-head'.
When the flag return_false_neg is turned on, the sampler will also check
if the generated negative edges are true negative edges and will return
a vector that indicates false negative edges. The vector is stored in
the negative graph as `false_neg` edge data.
When checking false negative edges, a user can provide edge relations
for a knowledge graph. A negative edge is considered as a false negative
edge only if the triple (source node, destination node and relation)
matches one of the edges in the graph.
Parameters Parameters
---------- ----------
...@@ -464,20 +481,24 @@ class EdgeSampler(object): ...@@ -464,20 +481,24 @@ class EdgeSampler(object):
The DGLGraph where we sample edges. The DGLGraph where we sample edges.
batch_size : int batch_size : int
The batch size (i.e, the number of edges from the graph) The batch size (i.e, the number of edges from the graph)
seed_edges : tensor seed_edges : tensor, optional
A list of edges where we sample from. A list of edges where we sample from.
shuffle : bool shuffle : bool, optional
whether randomly shuffle the list of edges where we sample from. whether randomly shuffle the list of edges where we sample from.
num_workers : int num_workers : int, optional
The number of workers to sample edges in parallel. The number of workers to sample edges in parallel.
prefetch : bool, optional prefetch : bool, optional
If true, prefetch the samples in the next batch. Default: False If true, prefetch the samples in the next batch. Default: False
negative_mode : string negative_mode : string, optional
The method used to construct negative edges. Possible values are 'head', 'tail'. The method used to construct negative edges. Possible values are 'head', 'tail'.
neg_sample_size : int neg_sample_size : int, optional
The number of negative edges to sample for each edge. The number of negative edges to sample for each edge.
exclude_positive : int exclude_positive : int, optional
Whether to exclude positive edges from the negative edges. Whether to exclude positive edges from the negative edges.
return_false_neg: bool, optional
Whether to calculate false negative edges and return them as edge data in negative graphs.
relations: tensor, optional
relations of the edges if this is a knowledge graph.
Class properties Class properties
---------------- ----------------
...@@ -498,6 +519,7 @@ class EdgeSampler(object): ...@@ -498,6 +519,7 @@ class EdgeSampler(object):
negative_mode="", negative_mode="",
neg_sample_size=0, neg_sample_size=0,
exclude_positive=False, exclude_positive=False,
return_false_neg=False,
relations=None): relations=None):
self._g = g self._g = g
if self.immutable_only and not g._graph.is_readonly(): if self.immutable_only and not g._graph.is_readonly():
...@@ -511,6 +533,7 @@ class EdgeSampler(object): ...@@ -511,6 +533,7 @@ class EdgeSampler(object):
assert g.number_of_edges() == len(relations) assert g.number_of_edges() == len(relations)
self._relations = relations self._relations = relations
self._return_false_neg = return_false_neg
self._batch_size = int(batch_size) self._batch_size = int(batch_size)
if seed_edges is None: if seed_edges is None:
...@@ -555,6 +578,7 @@ class EdgeSampler(object): ...@@ -555,6 +578,7 @@ class EdgeSampler(object):
self._negative_mode, self._negative_mode,
self._neg_sample_size, self._neg_sample_size,
self._exclude_positive, self._exclude_positive,
self._return_false_neg,
self._relations) self._relations)
if len(subgs) == 0: if len(subgs) == 0:
...@@ -570,8 +594,9 @@ class EdgeSampler(object): ...@@ -570,8 +594,9 @@ class EdgeSampler(object):
for i in range(num_pos): for i in range(num_pos):
pos_subg = subgraph.DGLSubGraph(self.g, subgs[i]) pos_subg = subgraph.DGLSubGraph(self.g, subgs[i])
neg_subg = subgraph.DGLSubGraph(self.g, subgs[i + num_pos]) neg_subg = subgraph.DGLSubGraph(self.g, subgs[i + num_pos])
exist = _CAPI_GetNegEdgeExistence(subgs[i + num_pos]); if self._return_false_neg:
neg_subg.edata['exist'] = utils.toindex(exist).tousertensor() exist = _CAPI_GetNegEdgeExistence(subgs[i + num_pos]);
neg_subg.edata['false_neg'] = utils.toindex(exist).tousertensor()
rets.append((pos_subg, neg_subg)) rets.append((pos_subg, neg_subg))
return rets return rets
......
...@@ -107,12 +107,19 @@ class ArrayHeap { ...@@ -107,12 +107,19 @@ class ArrayHeap {
* Uniformly sample integers from [0, set_size) without replacement. * Uniformly sample integers from [0, set_size) without replacement.
*/ */
void RandomSample(size_t set_size, size_t num, std::vector<size_t>* out) { void RandomSample(size_t set_size, size_t num, std::vector<size_t>* out) {
std::unordered_set<size_t> sampled_idxs;
while (sampled_idxs.size() < num) {
sampled_idxs.insert(RandomEngine::ThreadLocal()->RandInt(set_size));
}
out->clear(); out->clear();
out->insert(out->end(), sampled_idxs.begin(), sampled_idxs.end()); if (num < set_size) {
std::unordered_set<size_t> sampled_idxs;
while (sampled_idxs.size() < num) {
sampled_idxs.insert(RandomEngine::ThreadLocal()->RandInt(set_size));
}
out->insert(out->end(), sampled_idxs.begin(), sampled_idxs.end());
} else {
// If we need to sample all elements in the set, we don't need to
// generate random numbers.
for (size_t i = 0; i < set_size; i++)
out->push_back(i);
}
} }
void RandomSample(size_t set_size, size_t num, const std::vector<size_t> &exclude, void RandomSample(size_t set_size, size_t num, const std::vector<size_t> &exclude,
...@@ -121,14 +128,25 @@ void RandomSample(size_t set_size, size_t num, const std::vector<size_t> &exclud ...@@ -121,14 +128,25 @@ void RandomSample(size_t set_size, size_t num, const std::vector<size_t> &exclud
for (auto v : exclude) { for (auto v : exclude) {
sampled_idxs.insert(std::pair<size_t, int>(v, 0)); sampled_idxs.insert(std::pair<size_t, int>(v, 0));
} }
while (sampled_idxs.size() < num + exclude.size()) {
size_t rand = RandomEngine::ThreadLocal()->RandInt(set_size);
sampled_idxs.insert(std::pair<size_t, int>(rand, 1));
}
out->clear(); out->clear();
for (auto it = sampled_idxs.begin(); it != sampled_idxs.end(); it++) { if (num + exclude.size() < set_size) {
if (it->second) { while (sampled_idxs.size() < num + exclude.size()) {
out->push_back(it->first); size_t rand = RandomEngine::ThreadLocal()->RandInt(set_size);
sampled_idxs.insert(std::pair<size_t, int>(rand, 1));
}
for (auto it = sampled_idxs.begin(); it != sampled_idxs.end(); it++) {
if (it->second) {
out->push_back(it->first);
}
}
} else {
// If we need to sample all elements in the set, we don't need to
// generate random numbers.
for (size_t i = 0; i < set_size; i++) {
// If the element doesn't exist in exclude.
if (sampled_idxs.find(i) == sampled_idxs.end()) {
out->push_back(i);
}
} }
} }
} }
...@@ -903,7 +921,7 @@ dgl_id_t global2local_map(dgl_id_t global_id, ...@@ -903,7 +921,7 @@ dgl_id_t global2local_map(dgl_id_t global_id,
} }
} }
inline bool is_neg_head_mode(const std::string &mode) { inline bool IsNegativeHeadMode(const std::string &mode) {
return mode == "head"; return mode == "head";
} }
...@@ -963,7 +981,8 @@ IdArray CheckExistence(GraphPtr gptr, IdArray relations, ...@@ -963,7 +981,8 @@ IdArray CheckExistence(GraphPtr gptr, IdArray relations,
NegSubgraph NegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph &pos_subg, NegSubgraph NegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph &pos_subg,
const std::string &neg_mode, const std::string &neg_mode,
int neg_sample_size, bool exclude_positive) { int neg_sample_size, bool exclude_positive,
bool check_false_neg) {
int64_t num_tot_nodes = gptr->NumVertices(); int64_t num_tot_nodes = gptr->NumVertices();
bool is_multigraph = gptr->IsMultigraph(); bool is_multigraph = gptr->IsMultigraph();
std::vector<IdArray> adj = pos_subg.graph->GetAdj(false, "coo"); std::vector<IdArray> adj = pos_subg.graph->GetAdj(false, "coo");
...@@ -988,40 +1007,64 @@ NegSubgraph NegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph &po ...@@ -988,40 +1007,64 @@ NegSubgraph NegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph &po
dgl_id_t *neg_eid_data = static_cast<dgl_id_t *>(neg_eid->data); dgl_id_t *neg_eid_data = static_cast<dgl_id_t *>(neg_eid->data);
dgl_id_t *induced_neg_eid_data = static_cast<dgl_id_t *>(induced_neg_eid->data); dgl_id_t *induced_neg_eid_data = static_cast<dgl_id_t *>(induced_neg_eid->data);
const dgl_id_t *unchanged;
dgl_id_t *neg_unchanged;
dgl_id_t *neg_changed;
if (IsNegativeHeadMode(neg_mode)) {
unchanged = dst_data;
neg_unchanged = neg_dst_data;
neg_changed = neg_src_data;
} else {
unchanged = src_data;
neg_unchanged = neg_src_data;
neg_changed = neg_dst_data;
}
dgl_id_t curr_eid = 0; dgl_id_t curr_eid = 0;
std::vector<size_t> neg_vids; std::vector<size_t> neg_vids;
neg_vids.reserve(neg_sample_size); neg_vids.reserve(neg_sample_size);
std::unordered_map<dgl_id_t, dgl_id_t> neg_map; std::unordered_map<dgl_id_t, dgl_id_t> neg_map;
// If we don't exclude positive edges, we are actually sampling more than
// the total number of nodes in the graph.
if (!exclude_positive && neg_sample_size >= num_tot_nodes) {
// We add all nodes as negative nodes.
for (int64_t i = 0; i < num_tot_nodes; i++) {
neg_vids.push_back(i);
neg_map[i] = i;
}
}
for (int64_t i = 0; i < num_pos_edges; i++) { for (int64_t i = 0; i < num_pos_edges; i++) {
size_t neg_idx = i * neg_sample_size; size_t neg_idx = i * neg_sample_size;
neg_vids.clear();
std::vector<size_t> neighbors; std::vector<size_t> neighbors;
DGLIdIters neigh_it; DGLIdIters neigh_it;
const dgl_id_t *unchanged; if (IsNegativeHeadMode(neg_mode)) {
dgl_id_t *neg_unchanged;
dgl_id_t *neg_changed;
if (is_neg_head_mode(neg_mode)) {
unchanged = dst_data;
neg_unchanged = neg_dst_data;
neg_changed = neg_src_data;
neigh_it = gptr->PredVec(induced_vid_data[unchanged[i]]); neigh_it = gptr->PredVec(induced_vid_data[unchanged[i]]);
} else { } else {
unchanged = src_data;
neg_unchanged = neg_src_data;
neg_changed = neg_dst_data;
neigh_it = gptr->SuccVec(induced_vid_data[unchanged[i]]); neigh_it = gptr->SuccVec(induced_vid_data[unchanged[i]]);
} }
if (exclude_positive) { // If the number of negative nodes is smaller than the number of total nodes
// in the graph.
if (exclude_positive && neg_sample_size < num_tot_nodes) {
std::vector<size_t> exclude; std::vector<size_t> exclude;
for (auto it = neigh_it.begin(); it != neigh_it.end(); it++) { for (auto it = neigh_it.begin(); it != neigh_it.end(); it++) {
dgl_id_t global_vid = *it; dgl_id_t global_vid = *it;
exclude.push_back(global_vid); exclude.push_back(global_vid);
} }
neg_vids.clear();
RandomSample(num_tot_nodes, neg_sample_size, exclude, &neg_vids); RandomSample(num_tot_nodes, neg_sample_size, exclude, &neg_vids);
} else { } else if (neg_sample_size < num_tot_nodes) {
neg_vids.clear();
RandomSample(num_tot_nodes, neg_sample_size, &neg_vids); RandomSample(num_tot_nodes, neg_sample_size, &neg_vids);
} else if (exclude_positive) {
LOG(FATAL) << "We can't exclude positive edges when sampling negative edges with all nodes.";
} else {
// We don't need to do anything here.
// In this case, every edge has the same negative edges. That is,
// neg_vids contains all nodes of the graph. They have been generated
// before the for loop.
} }
dgl_id_t global_unchanged = induced_vid_data[unchanged[i]]; dgl_id_t global_unchanged = induced_vid_data[unchanged[i]];
...@@ -1053,11 +1096,13 @@ NegSubgraph NegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph &po ...@@ -1053,11 +1096,13 @@ NegSubgraph NegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph &po
neg_subg.induced_vertices = induced_neg_vid; neg_subg.induced_vertices = induced_neg_vid;
neg_subg.induced_edges = induced_neg_eid; neg_subg.induced_edges = induced_neg_eid;
// TODO(zhengda) we should provide an array of 1s if exclude_positive // TODO(zhengda) we should provide an array of 1s if exclude_positive
if (relations->shape[0] == 0) { if (check_false_neg) {
neg_subg.exist = CheckExistence(gptr, neg_src, neg_dst, induced_neg_vid); if (relations->shape[0] == 0) {
} else { neg_subg.exist = CheckExistence(gptr, neg_src, neg_dst, induced_neg_vid);
neg_subg.exist = CheckExistence(gptr, relations, neg_src, neg_dst, } else {
induced_neg_vid, induced_neg_eid); neg_subg.exist = CheckExistence(gptr, relations, neg_src, neg_dst,
induced_neg_vid, induced_neg_eid);
}
} }
return neg_subg; return neg_subg;
} }
...@@ -1065,7 +1110,7 @@ NegSubgraph NegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph &po ...@@ -1065,7 +1110,7 @@ NegSubgraph NegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph &po
NegSubgraph PBGNegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph &pos_subg, NegSubgraph PBGNegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph &pos_subg,
const std::string &neg_mode, const std::string &neg_mode,
int neg_sample_size, bool is_multigraph, int neg_sample_size, bool is_multigraph,
bool exclude_positive) { bool exclude_positive, bool check_false_neg) {
int64_t num_tot_nodes = gptr->NumVertices(); int64_t num_tot_nodes = gptr->NumVertices();
std::vector<IdArray> adj = pos_subg.graph->GetAdj(false, "coo"); std::vector<IdArray> adj = pos_subg.graph->GetAdj(false, "coo");
IdArray coo = adj[0]; IdArray coo = adj[0];
...@@ -1111,7 +1156,7 @@ NegSubgraph PBGNegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph ...@@ -1111,7 +1156,7 @@ NegSubgraph PBGNegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph
dgl_id_t *neg_changed; dgl_id_t *neg_changed;
// corrupt head nodes. // corrupt head nodes.
if (is_neg_head_mode(neg_mode)) { if (IsNegativeHeadMode(neg_mode)) {
unchanged = dst_data; unchanged = dst_data;
neg_unchanged = neg_dst_data; neg_unchanged = neg_dst_data;
neg_changed = neg_src_data; neg_changed = neg_src_data;
...@@ -1176,11 +1221,13 @@ NegSubgraph PBGNegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph ...@@ -1176,11 +1221,13 @@ NegSubgraph PBGNegEdgeSubgraph(GraphPtr gptr, IdArray relations, const Subgraph
neg_subg.graph = GraphPtr(new ImmutableGraph(neg_coo)); neg_subg.graph = GraphPtr(new ImmutableGraph(neg_coo));
neg_subg.induced_vertices = induced_neg_vid; neg_subg.induced_vertices = induced_neg_vid;
neg_subg.induced_edges = induced_neg_eid; neg_subg.induced_edges = induced_neg_eid;
if (relations->shape[0] == 0) { if (check_false_neg) {
neg_subg.exist = CheckExistence(gptr, neg_src, neg_dst, induced_neg_vid); if (relations->shape[0] == 0) {
} else { neg_subg.exist = CheckExistence(gptr, neg_src, neg_dst, induced_neg_vid);
neg_subg.exist = CheckExistence(gptr, relations, neg_src, neg_dst, } else {
induced_neg_vid, induced_neg_eid); neg_subg.exist = CheckExistence(gptr, relations, neg_src, neg_dst,
induced_neg_vid, induced_neg_eid);
}
} }
return neg_subg; return neg_subg;
} }
...@@ -1206,7 +1253,8 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformEdgeSampling") ...@@ -1206,7 +1253,8 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformEdgeSampling")
const std::string neg_mode = args[5]; const std::string neg_mode = args[5];
const int neg_sample_size = args[6]; const int neg_sample_size = args[6];
const bool exclude_positive = args[7]; const bool exclude_positive = args[7];
IdArray relations = args[8]; const bool check_false_neg = args[8];
IdArray relations = args[9];
// process args // process args
auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()); auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(gptr) << "sampling isn't implemented in mutable graph"; CHECK(gptr) << "sampling isn't implemented in mutable graph";
...@@ -1240,11 +1288,12 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformEdgeSampling") ...@@ -1240,11 +1288,12 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformEdgeSampling")
if (neg_mode.substr(0, 3) == "PBG") { if (neg_mode.substr(0, 3) == "PBG") {
NegSubgraph neg_subg = PBGNegEdgeSubgraph(gptr, relations, subg, NegSubgraph neg_subg = PBGNegEdgeSubgraph(gptr, relations, subg,
neg_mode.substr(4), neg_sample_size, neg_mode.substr(4), neg_sample_size,
gptr->IsMultigraph(), exclude_positive); gptr->IsMultigraph(), exclude_positive,
check_false_neg);
negative_subgs[i] = ConvertRef(neg_subg); negative_subgs[i] = ConvertRef(neg_subg);
} else if (neg_mode.size() > 0) { } else if (neg_mode.size() > 0) {
NegSubgraph neg_subg = NegEdgeSubgraph(gptr, relations, subg, neg_mode, neg_sample_size, NegSubgraph neg_subg = NegEdgeSubgraph(gptr, relations, subg, neg_mode, neg_sample_size,
exclude_positive); exclude_positive, check_false_neg);
negative_subgs[i] = ConvertRef(neg_subg); negative_subgs[i] = ConvertRef(neg_subg);
} }
} }
......
...@@ -220,7 +220,7 @@ def test_setseed(): ...@@ -220,7 +220,7 @@ def test_setseed():
g, 5, 3, num_hops=2, neighbor_type='in', num_workers=4)): g, 5, 3, num_hops=2, neighbor_type='in', num_workers=4)):
pass pass
def check_negative_sampler(mode, exclude_positive): def check_negative_sampler(mode, exclude_positive, neg_size):
g = generate_rand_graph(100) g = generate_rand_graph(100)
etype = np.random.randint(0, 10, size=g.number_of_edges(), dtype=np.int64) etype = np.random.randint(0, 10, size=g.number_of_edges(), dtype=np.int64)
g.edata['etype'] = F.tensor(etype) g.edata['etype'] = F.tensor(etype)
...@@ -233,12 +233,12 @@ def check_negative_sampler(mode, exclude_positive): ...@@ -233,12 +233,12 @@ def check_negative_sampler(mode, exclude_positive):
pos_map[(pos_d, pos_e)] = int(F.asnumpy(pos_gsrc[i])) pos_map[(pos_d, pos_e)] = int(F.asnumpy(pos_gsrc[i]))
EdgeSampler = getattr(dgl.contrib.sampling, 'EdgeSampler') EdgeSampler = getattr(dgl.contrib.sampling, 'EdgeSampler')
neg_size = 10
# Test the homogeneous graph. # Test the homogeneous graph.
for pos_edges, neg_edges in EdgeSampler(g, 50, for pos_edges, neg_edges in EdgeSampler(g, 50,
negative_mode=mode, negative_mode=mode,
neg_sample_size=neg_size, neg_sample_size=neg_size,
exclude_positive=exclude_positive): exclude_positive=exclude_positive,
return_false_neg=True):
pos_lsrc, pos_ldst, pos_leid = pos_edges.all_edges(form='all', order='eid') pos_lsrc, pos_ldst, pos_leid = pos_edges.all_edges(form='all', order='eid')
assert_array_equal(F.asnumpy(pos_edges.parent_eid[pos_leid]), assert_array_equal(F.asnumpy(pos_edges.parent_eid[pos_leid]),
F.asnumpy(g.edge_ids(pos_edges.parent_nid[pos_lsrc], F.asnumpy(g.edge_ids(pos_edges.parent_nid[pos_lsrc],
...@@ -255,7 +255,7 @@ def check_negative_sampler(mode, exclude_positive): ...@@ -255,7 +255,7 @@ def check_negative_sampler(mode, exclude_positive):
if exclude_positive: if exclude_positive:
assert int(F.asnumpy(neg_src[i])) != pos_map[(neg_d, neg_e)] assert int(F.asnumpy(neg_src[i])) != pos_map[(neg_d, neg_e)]
exist = neg_edges.edata['exist'] exist = neg_edges.edata['false_neg']
if exclude_positive: if exclude_positive:
assert np.sum(F.asnumpy(exist) == 0) == len(exist) assert np.sum(F.asnumpy(exist) == 0) == len(exist)
else: else:
...@@ -266,12 +266,13 @@ def check_negative_sampler(mode, exclude_positive): ...@@ -266,12 +266,13 @@ def check_negative_sampler(mode, exclude_positive):
negative_mode=mode, negative_mode=mode,
neg_sample_size=neg_size, neg_sample_size=neg_size,
exclude_positive=exclude_positive, exclude_positive=exclude_positive,
relations=g.edata['etype']): relations=g.edata['etype'],
return_false_neg=True):
neg_lsrc, neg_ldst, neg_leid = neg_edges.all_edges(form='all', order='eid') neg_lsrc, neg_ldst, neg_leid = neg_edges.all_edges(form='all', order='eid')
neg_src = neg_edges.parent_nid[neg_lsrc] neg_src = neg_edges.parent_nid[neg_lsrc]
neg_dst = neg_edges.parent_nid[neg_ldst] neg_dst = neg_edges.parent_nid[neg_ldst]
neg_eid = neg_edges.parent_eid[neg_leid] neg_eid = neg_edges.parent_eid[neg_leid]
exists = neg_edges.edata['exist'] exists = neg_edges.edata['false_neg']
neg_edges.edata['etype'] = g.edata['etype'][neg_eid] neg_edges.edata['etype'] = g.edata['etype'][neg_eid]
for i in range(len(neg_eid)): for i in range(len(neg_eid)):
u, v = F.asnumpy(neg_src[i]), F.asnumpy(neg_dst[i]) u, v = F.asnumpy(neg_src[i]), F.asnumpy(neg_dst[i])
...@@ -282,8 +283,10 @@ def check_negative_sampler(mode, exclude_positive): ...@@ -282,8 +283,10 @@ def check_negative_sampler(mode, exclude_positive):
assert F.asnumpy(exists[i]) == F.asnumpy(exist) assert F.asnumpy(exists[i]) == F.asnumpy(exist)
def test_negative_sampler(): def test_negative_sampler():
check_negative_sampler('head', True) check_negative_sampler('PBG-head', False, 10)
check_negative_sampler('PBG-head', False) check_negative_sampler('head', True, 10)
check_negative_sampler('head', False, 10)
check_negative_sampler('head', False, 100)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment