Commit cd9599ab authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by Da Zheng
Browse files

[KG][Optimization] Remove copy from parent in minibatch generation (#1193)

* remote copy from parent

* upd

* make test_sampler easier to pass
parent 28a24414
...@@ -70,8 +70,7 @@ def ConstructGraph(edges, n_entities, args): ...@@ -70,8 +70,7 @@ def ConstructGraph(edges, n_entities, args):
src, etype_id, dst = edges src, etype_id, dst = edges
coo = sp.sparse.coo_matrix((np.ones(len(src)), (src, dst)), shape=[n_entities, n_entities]) coo = sp.sparse.coo_matrix((np.ones(len(src)), (src, dst)), shape=[n_entities, n_entities])
g = dgl.DGLGraph(coo, readonly=True, sort_csr=True) g = dgl.DGLGraph(coo, readonly=True, sort_csr=True)
g.ndata['id'] = F.arange(0, g.number_of_nodes()) g.edata['tid'] = F.tensor(etype_id, F.int64)
g.edata['id'] = F.tensor(etype_id, F.int64)
if args.pickle_graph: if args.pickle_graph:
with open(os.path.join(args.data_path, args.dataset, pickle_name), 'wb') as graph_file: with open(os.path.join(args.data_path, args.dataset, pickle_name), 'wb') as graph_file:
pickle.dump(g, graph_file) pickle.dump(g, graph_file)
...@@ -189,7 +188,7 @@ class EvalSampler(object): ...@@ -189,7 +188,7 @@ class EvalSampler(object):
num_workers=num_workers, num_workers=num_workers,
shuffle=False, shuffle=False,
exclude_positive=False, exclude_positive=False,
relations=g.edata['id'], relations=g.edata['tid'],
return_false_neg=filter_false_neg) return_false_neg=filter_false_neg)
self.sampler_iter = iter(self.sampler) self.sampler_iter = iter(self.sampler)
self.mode = mode self.mode = mode
...@@ -211,8 +210,9 @@ class EvalSampler(object): ...@@ -211,8 +210,9 @@ class EvalSampler(object):
if neg_g is not None: if neg_g is not None:
break break
pos_g.copy_from_parent() pos_g.ndata['id'] = pos_g.parent_nid
neg_g.copy_from_parent() neg_g.ndata['id'] = neg_g.parent_nid
pos_g.edata['id'] = pos_g._parent.edata['tid'][pos_g.parent_eid]
if self.filter_false_neg: if self.filter_false_neg:
neg_g.edata['bias'] = F.astype(-neg_positive, F.float32) neg_g.edata['bias'] = F.astype(-neg_positive, F.float32)
return pos_g, neg_g return pos_g, neg_g
...@@ -235,13 +235,11 @@ class EvalDataset(object): ...@@ -235,13 +235,11 @@ class EvalDataset(object):
coo = sp.sparse.coo_matrix((np.ones(len(src)), (src, dst)), coo = sp.sparse.coo_matrix((np.ones(len(src)), (src, dst)),
shape=[dataset.n_entities, dataset.n_entities]) shape=[dataset.n_entities, dataset.n_entities])
g = dgl.DGLGraph(coo, readonly=True, sort_csr=True) g = dgl.DGLGraph(coo, readonly=True, sort_csr=True)
g.ndata['id'] = F.arange(0, g.number_of_nodes()) g.edata['tid'] = F.tensor(etype_id, F.int64)
g.edata['id'] = F.tensor(etype_id, F.int64)
if args.pickle_graph: if args.pickle_graph:
with open(os.path.join(args.data_path, args.dataset, pickle_name), 'wb') as graph_file: with open(os.path.join(args.data_path, args.dataset, pickle_name), 'wb') as graph_file:
pickle.dump(g, graph_file) pickle.dump(g, graph_file)
self.g = g self.g = g
self.num_train = len(dataset.train[0]) self.num_train = len(dataset.train[0])
self.num_valid = len(dataset.valid[0]) self.num_valid = len(dataset.valid[0])
self.num_test = len(dataset.test[0]) self.num_test = len(dataset.test[0])
...@@ -329,6 +327,7 @@ class NewBidirectionalOneShotIterator: ...@@ -329,6 +327,7 @@ class NewBidirectionalOneShotIterator:
if neg_g is None: if neg_g is None:
continue continue
pos_g.copy_from_parent() pos_g.ndata['id'] = pos_g.parent_nid
neg_g.copy_from_parent() neg_g.ndata['id'] = neg_g.parent_nid
pos_g.edata['id'] = pos_g._parent.edata['tid'][pos_g.parent_eid]
yield pos_g, neg_g yield pos_g, neg_g
...@@ -670,7 +670,7 @@ def check_positive_edge_sampler(): ...@@ -670,7 +670,7 @@ def check_positive_edge_sampler():
num_edges = g.number_of_edges() num_edges = g.number_of_edges()
edge_weight = F.copy_to(F.tensor(np.full((num_edges,), 1, dtype=np.float32)), F.cpu()) edge_weight = F.copy_to(F.tensor(np.full((num_edges,), 1, dtype=np.float32)), F.cpu())
edge_weight[num_edges-1] = num_edges ** 2 edge_weight[num_edges-1] = num_edges ** 3
EdgeSampler = getattr(dgl.contrib.sampling, 'EdgeSampler') EdgeSampler = getattr(dgl.contrib.sampling, 'EdgeSampler')
# Correctness check # Correctness check
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment