Unverified commit f8184153 authored by Da Zheng, committed by GitHub

[KG] reduce memory consumption. (#902)

* reduce memory consumption.

* fix a bug.

* fix a bug.

* fix.
parent 655d7568
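In short, the commit stores each triple set as three contiguous int64 NumPy arrays (heads, rels, tails) instead of a Python list of (h, r, t) tuples, and it partitions training edges by index over one shared graph instead of building a separate graph per partition. The sketch below is a rough, back-of-envelope comparison of the two layouts (not part of the commit; exact numbers depend on the Python build):

    # Rough estimate only: ~100k synthetic triples stored as a list of Python
    # tuples (old layout) vs. three int64 NumPy arrays (new layout).
    import sys
    import numpy as np

    n = 100000
    triples = [(i, i % 100, (i + 1) % n) for i in range(n)]          # old layout
    heads = np.array([h for h, _, _ in triples], dtype=np.int64)     # new layout
    rels = np.array([r for _, r, _ in triples], dtype=np.int64)
    tails = np.array([t for _, _, t in triples], dtype=np.int64)

    old_bytes = sys.getsizeof(triples) + sum(
        sys.getsizeof(t) + sum(sys.getsizeof(x) for x in t) for t in triples)
    new_bytes = heads.nbytes + rels.nbytes + tails.nbytes
    print('list of tuples    : ~{:.0f} bytes per triple'.format(old_bytes / n))
    print('three int64 arrays: {:.0f} bytes per triple'.format(new_bytes / n))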
 import os
+import numpy as np

 def _download_and_extract(url, path, filename):
     import shutil, zipfile
@@ -71,13 +72,20 @@ class KGDataset1:
     def read_triple(self, path, mode):
         # mode: train/valid/test
-        triples = []
+        heads = []
+        tails = []
+        rels = []
         with open(os.path.join(path, '{}.txt'.format(mode))) as f:
             for line in f:
                 h, r, t = line.strip().split('\t')
-                triples.append((self.entity2id[h], self.relation2id[r], self.entity2id[t]))
+                heads.append(self.entity2id[h])
+                rels.append(self.relation2id[r])
+                tails.append(self.entity2id[t])
+        heads = np.array(heads, dtype=np.int64)
+        tails = np.array(tails, dtype=np.int64)
+        rels = np.array(rels, dtype=np.int64)

-        return triples
+        return (heads, rels, tails)

 class KGDataset2:
@@ -115,16 +123,23 @@ class KGDataset2:
         self.test = self.read_triple(self.path, 'test')

     def read_triple(self, path, mode, skip_first_line=False):
-        triples = []
+        heads = []
+        tails = []
+        rels = []
         print('Reading {} triples....'.format(mode))
         with open(os.path.join(path, '{}.txt'.format(mode))) as f:
             if skip_first_line:
                 _ = f.readline()
             for line in f:
                 h, t, r = line.strip().split('\t')
-                triples.append((int(h), int(r), int(t)))
-        print('Finished. Read {} {} triples.'.format(len(triples), mode))
-        return triples
+                heads.append(int(h))
+                tails.append(int(t))
+                rels.append(int(r))
+        heads = np.array(heads, dtype=np.int64)
+        tails = np.array(tails, dtype=np.int64)
+        rels = np.array(rels, dtype=np.int64)
+        print('Finished. Read {} {} triples.'.format(len(heads), mode))
+        return (heads, rels, tails)

 def get_dataset(data_path, data_name, format_str):
......
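Callers of read_triple now receive a (heads, rels, tails) tuple of arrays rather than a list of tuples. A minimal sketch of consuming that format (toy data, not from the diff):

    import numpy as np

    # Toy stand-in for what KGDataset1.read_triple now returns.
    train = (np.array([0, 1, 2], dtype=np.int64),   # heads
             np.array([0, 0, 1], dtype=np.int64),   # rels
             np.array([1, 2, 0], dtype=np.int64))   # tails
    heads, rels, tails = train
    print('num triples:', len(heads))
    for h, r, t in zip(heads, rels, tails):         # per-triple view when needed
        print(h, r, t)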
@@ -11,9 +11,9 @@ import time
 # This partitions a list of edges based on relations to make sure
 # each partition has roughly the same number of edges and relations.
 def RelationPartition(edges, n):
-    print('relation partition {} edges into {} parts'.format(len(edges), n))
-    rel = np.array([r for h, r, t in edges])
-    uniq, cnts = np.unique(rel, return_counts=True)
+    heads, rels, tails = edges
+    print('relation partition {} edges into {} parts'.format(len(heads), n))
+    uniq, cnts = np.unique(rels, return_counts=True)
     idx = np.flip(np.argsort(cnts))
     cnts = cnts[idx]
     uniq = uniq[idx]
@@ -30,35 +30,39 @@ def RelationPartition(edges, n):
         rel_cnts[idx] += 1
     for i, edge_cnt in enumerate(edge_cnts):
         print('part {} has {} edges and {} relations'.format(i, edge_cnt, rel_cnts[i]))
     parts = []
     for _ in range(n):
         parts.append([])
-    for h, r, t in edges:
-        idx = rel_dict[r]
-        parts[idx].append((h, r, t))
+    # let's store the edge index to each partition first.
+    for i, r in enumerate(rels):
+        part_idx = rel_dict[r]
+        parts[part_idx].append(i)
+    for i, part in enumerate(parts):
+        parts[i] = np.array(part, dtype=np.int64)
     return parts

 def RandomPartition(edges, n):
-    print('random partition {} edges into {} parts'.format(len(edges), n))
-    idx = np.random.permutation(len(edges))
+    heads, rels, tails = edges
+    print('random partition {} edges into {} parts'.format(len(heads), n))
+    idx = np.random.permutation(len(heads))
     part_size = int(math.ceil(len(idx) / n))
     parts = []
     for i in range(n):
         start = part_size * i
         end = min(part_size * (i + 1), len(idx))
-        parts.append([edges[i] for i in idx[start:end]])
+        parts.append(idx[start:end])
+        print('part {} has {} edges'.format(i, len(parts[-1])))
     return parts

-def ConstructGraph(edges, n_entities, i, args):
-    pickle_name = 'graph_train_{}.pickle'.format(i)
+def ConstructGraph(edges, n_entities, args):
+    pickle_name = 'graph_train.pickle'
     if args.pickle_graph and os.path.exists(os.path.join(args.data_path, args.dataset, pickle_name)):
         with open(os.path.join(args.data_path, args.dataset, pickle_name), 'rb') as graph_file:
             g = pickle.load(graph_file)
         print('Load pickled graph.')
     else:
-        src = [t[0] for t in edges]
-        etype_id = [t[1] for t in edges]
-        dst = [t[2] for t in edges]
+        src, etype_id, dst = edges
         coo = sp.sparse.coo_matrix((np.ones(len(src)), (src, dst)), shape=[n_entities, n_entities])
         g = dgl.DGLGraph(coo, readonly=True, sort_csr=True)
         g.ndata['id'] = F.arange(0, g.number_of_nodes())
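With this change both partitioners return int64 arrays of edge indices into the shared edge list, so n partitions only add n index arrays instead of n copies of the triples, and ConstructGraph is called once on the full edge set. A simplified, self-contained sketch of index-based relation partitioning (toy data; the balancing logic above is more involved):

    import numpy as np

    rels = np.array([0, 0, 0, 0, 1, 2, 2, 1, 1], dtype=np.int64)
    n_parts = 2

    # Greedily assign whole relations, largest first, to the lightest partition.
    uniq, cnts = np.unique(rels, return_counts=True)
    order = np.flip(np.argsort(cnts))
    loads = [0] * n_parts
    rel2part = {}
    for r, c in zip(uniq[order], cnts[order]):
        p = int(np.argmin(loads))
        rel2part[r] = p
        loads[p] += c

    # Each partition stores edge indices, not copies of (h, r, t) triples.
    parts = [[] for _ in range(n_parts)]
    for i, r in enumerate(rels):
        parts[rel2part[r]].append(i)
    parts = [np.array(p, dtype=np.int64) for p in parts]
    print(parts)   # [array([0, 1, 2, 3]), array([4, 5, 6, 7, 8])]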
@@ -71,26 +75,23 @@ def ConstructGraph(edges, n_entities, i, args):
 class TrainDataset(object):
     def __init__(self, dataset, args, weighting=False, ranks=64):
         triples = dataset.train
-        print('|Train|:', len(triples))
+        self.g = ConstructGraph(triples, dataset.n_entities, args)
+        num_train = len(triples[0])
+        print('|Train|:', num_train)
         if ranks > 1 and args.rel_part:
-            triples_list = RelationPartition(triples, ranks)
+            self.edge_parts = RelationPartition(triples, ranks)
         elif ranks > 1:
-            triples_list = RandomPartition(triples, ranks)
+            self.edge_parts = RandomPartition(triples, ranks)
         else:
-            triples_list = [triples]
-        self.graphs = []
-        for i, triples in enumerate(triples_list):
-            g = ConstructGraph(triples, dataset.n_entities, i, args)
-            if weighting:
-                # TODO: weight to be added
-                count = self.count_freq(triples)
-                subsampling_weight = np.vectorize(
-                    lambda h, r, t: np.sqrt(1 / (count[(h, r)] + count[(t, -r - 1)]))
-                )
-                weight = subsampling_weight(src, etype_id, dst)
-                g.edata['weight'] = F.zerocopy_from_numpy(weight)
-                # to be added
-            self.graphs.append(g)
+            self.edge_parts = [np.arange(num_train)]
+        if weighting:
+            # TODO: weight to be added
+            count = self.count_freq(triples)
+            subsampling_weight = np.vectorize(
+                lambda h, r, t: np.sqrt(1 / (count[(h, r)] + count[(t, -r - 1)]))
+            )
+            weight = subsampling_weight(src, etype_id, dst)
+            self.g.edata['weight'] = F.zerocopy_from_numpy(weight)

     def count_freq(self, triples, start=4):
         count = {}
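TrainDataset now keeps a single graph in self.g and only per-rank edge-index arrays in self.edge_parts; across ranks the partitions cover each training edge exactly once. A quick check of that invariant with hypothetical partitions (not from the commit):

    import numpy as np

    num_train = 10
    # e.g. what a RandomPartition-style split might return for 3 ranks
    perm = np.random.permutation(num_train)
    edge_parts = [perm[0:4], perm[4:8], perm[8:10]]

    covered = np.sort(np.concatenate(edge_parts))
    assert np.array_equal(covered, np.arange(num_train))   # each edge in exactly one part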
...@@ -109,7 +110,8 @@ class TrainDataset(object): ...@@ -109,7 +110,8 @@ class TrainDataset(object):
def create_sampler(self, batch_size, neg_sample_size=2, mode='head', num_workers=5, def create_sampler(self, batch_size, neg_sample_size=2, mode='head', num_workers=5,
shuffle=True, exclude_positive=False, rank=0): shuffle=True, exclude_positive=False, rank=0):
EdgeSampler = getattr(dgl.contrib.sampling, 'EdgeSampler') EdgeSampler = getattr(dgl.contrib.sampling, 'EdgeSampler')
return EdgeSampler(self.graphs[rank], return EdgeSampler(self.g,
seed_edges=F.tensor(self.edge_parts[rank]),
batch_size=batch_size, batch_size=batch_size,
neg_sample_size=neg_sample_size, neg_sample_size=neg_sample_size,
negative_mode=mode, negative_mode=mode,
...@@ -118,6 +120,7 @@ class TrainDataset(object): ...@@ -118,6 +120,7 @@ class TrainDataset(object):
exclude_positive=exclude_positive, exclude_positive=exclude_positive,
return_false_neg=False) return_false_neg=False)
class PBGNegEdgeSubgraph(dgl.subgraph.DGLSubGraph): class PBGNegEdgeSubgraph(dgl.subgraph.DGLSubGraph):
def __init__(self, subg, num_chunks, chunk_size, def __init__(self, subg, num_chunks, chunk_size,
neg_sample_size, neg_head): neg_sample_size, neg_head):
@@ -203,17 +206,17 @@ class EvalSampler(object):
 class EvalDataset(object):
     def __init__(self, dataset, args):
-        triples = dataset.train + dataset.valid + dataset.test
         pickle_name = 'graph_all.pickle'
         if args.pickle_graph and os.path.exists(os.path.join(args.data_path, args.dataset, pickle_name)):
             with open(os.path.join(args.data_path, args.dataset, pickle_name), 'rb') as graph_file:
                 g = pickle.load(graph_file)
             print('Load pickled graph.')
         else:
-            src = [t[0] for t in triples]
-            etype_id = [t[1] for t in triples]
-            dst = [t[2] for t in triples]
-            coo = sp.sparse.coo_matrix((np.ones(len(src)), (src, dst)), shape=[dataset.n_entities, dataset.n_entities])
+            src = np.concatenate((dataset.train[0], dataset.valid[0], dataset.test[0]))
+            etype_id = np.concatenate((dataset.train[1], dataset.valid[1], dataset.test[1]))
+            dst = np.concatenate((dataset.train[2], dataset.valid[2], dataset.test[2]))
+            coo = sp.sparse.coo_matrix((np.ones(len(src)), (src, dst)),
+                                       shape=[dataset.n_entities, dataset.n_entities])
             g = dgl.DGLGraph(coo, readonly=True, sort_csr=True)
             g.ndata['id'] = F.arange(0, g.number_of_nodes())
             g.edata['id'] = F.tensor(etype_id, F.int64)
@@ -222,9 +225,9 @@ class EvalDataset(object):
             pickle.dump(g, graph_file)
         self.g = g

-        self.num_train = len(dataset.train)
-        self.num_valid = len(dataset.valid)
-        self.num_test = len(dataset.test)
+        self.num_train = len(dataset.train[0])
+        self.num_valid = len(dataset.valid[0])
+        self.num_test = len(dataset.test[0])
         if args.eval_percent < 1:
             self.valid = np.random.randint(0, self.num_valid,
......
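Since EvalDataset concatenates the splits in train, valid, test order, edge ids in the full graph fall into three contiguous ranges delimited by num_train, num_valid and num_test. A toy sketch of that layout (illustrative arrays, not from the commit):

    import numpy as np

    train = (np.array([0, 1]), np.array([0, 0]), np.array([1, 2]))
    valid = (np.array([2]), np.array([1]), np.array([0]))
    test = (np.array([1]), np.array([1]), np.array([0]))

    src = np.concatenate((train[0], valid[0], test[0]))
    num_train, num_valid = len(train[0]), len(valid[0])

    # Edge ids [num_train, num_train + num_valid) belong to the validation split.
    valid_ids = np.arange(num_train, num_train + num_valid)
    print(src[valid_ids])   # [2]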