Commit cd9599ab authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by Da Zheng
Browse files

[KG][Optimization] Remove copy from parent in minibatch generation (#1193)

* remote copy from parent

* upd

* make test_sampler easier to pass
parent 28a24414
......@@ -70,8 +70,7 @@ def ConstructGraph(edges, n_entities, args):
src, etype_id, dst = edges
coo = sp.sparse.coo_matrix((np.ones(len(src)), (src, dst)), shape=[n_entities, n_entities])
g = dgl.DGLGraph(coo, readonly=True, sort_csr=True)
g.ndata['id'] = F.arange(0, g.number_of_nodes())
g.edata['id'] = F.tensor(etype_id, F.int64)
g.edata['tid'] = F.tensor(etype_id, F.int64)
if args.pickle_graph:
with open(os.path.join(args.data_path, args.dataset, pickle_name), 'wb') as graph_file:
pickle.dump(g, graph_file)
......@@ -189,7 +188,7 @@ class EvalSampler(object):
num_workers=num_workers,
shuffle=False,
exclude_positive=False,
relations=g.edata['id'],
relations=g.edata['tid'],
return_false_neg=filter_false_neg)
self.sampler_iter = iter(self.sampler)
self.mode = mode
......@@ -211,8 +210,9 @@ class EvalSampler(object):
if neg_g is not None:
break
pos_g.copy_from_parent()
neg_g.copy_from_parent()
pos_g.ndata['id'] = pos_g.parent_nid
neg_g.ndata['id'] = neg_g.parent_nid
pos_g.edata['id'] = pos_g._parent.edata['tid'][pos_g.parent_eid]
if self.filter_false_neg:
neg_g.edata['bias'] = F.astype(-neg_positive, F.float32)
return pos_g, neg_g
......@@ -235,13 +235,11 @@ class EvalDataset(object):
coo = sp.sparse.coo_matrix((np.ones(len(src)), (src, dst)),
shape=[dataset.n_entities, dataset.n_entities])
g = dgl.DGLGraph(coo, readonly=True, sort_csr=True)
g.ndata['id'] = F.arange(0, g.number_of_nodes())
g.edata['id'] = F.tensor(etype_id, F.int64)
g.edata['tid'] = F.tensor(etype_id, F.int64)
if args.pickle_graph:
with open(os.path.join(args.data_path, args.dataset, pickle_name), 'wb') as graph_file:
pickle.dump(g, graph_file)
self.g = g
self.num_train = len(dataset.train[0])
self.num_valid = len(dataset.valid[0])
self.num_test = len(dataset.test[0])
......@@ -329,6 +327,7 @@ class NewBidirectionalOneShotIterator:
if neg_g is None:
continue
pos_g.copy_from_parent()
neg_g.copy_from_parent()
pos_g.ndata['id'] = pos_g.parent_nid
neg_g.ndata['id'] = neg_g.parent_nid
pos_g.edata['id'] = pos_g._parent.edata['tid'][pos_g.parent_eid]
yield pos_g, neg_g
......@@ -670,7 +670,7 @@ def check_positive_edge_sampler():
num_edges = g.number_of_edges()
edge_weight = F.copy_to(F.tensor(np.full((num_edges,), 1, dtype=np.float32)), F.cpu())
edge_weight[num_edges-1] = num_edges ** 2
edge_weight[num_edges-1] = num_edges ** 3
EdgeSampler = getattr(dgl.contrib.sampling, 'EdgeSampler')
# Correctness check
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment