# eval.py
from dataloader import EvalDataset, TrainDataset
from dataloader import get_dataset

import argparse
import os
import logging
import time
import pickle

# Select the training backend from the DGLBACKEND environment variable
# (defaults to 'pytorch'); the comparison is case-insensitive and anything
# other than 'mxnet' takes the PyTorch path.
backend = os.environ.get('DGLBACKEND', 'pytorch')
if backend.lower() != 'mxnet':
    import torch.multiprocessing as mp
    from train_pytorch import load_model_from_checkpoint
    from train_pytorch import test, test_mp
else:
    # NOTE(review): this branch does not import `test_mp`, yet main() uses it
    # when --num_proc > 1 -- confirm whether train_mxnet provides one.
    import multiprocessing as mp
    from train_mxnet import load_model_from_checkpoint
    from train_mxnet import test

class ArgParser(argparse.ArgumentParser):
    """Command-line argument parser for the evaluation script.

    Registers the options that select the model, locate the dataset and the
    saved checkpoint, and configure negative sampling, devices and
    parallelism for evaluation.
    """

    def __init__(self):
        super(ArgParser, self).__init__()

        # Model and data location.
        self.add_argument('--model_name', default='TransE',
                          choices=['TransE', 'TransE_l1', 'TransE_l2', 'TransH', 'TransR', 'TransD',
                                   'RESCAL', 'DistMult', 'ComplEx', 'RotatE', 'pRotatE'],
                          help='model to use')
        self.add_argument('--data_path', type=str, default='data',
                          help='root path of all dataset')
        self.add_argument('--dataset', type=str, default='FB15k',
                          help='dataset name, under data_path')
        self.add_argument('--format', type=str, default='1',
                          help='the format of the dataset.')
        self.add_argument('--model_path', type=str, default='ckpts',
                          help='the place where models are saved')

        # Batching and negative sampling; main() replaces the negative
        # values of --neg_sample_size / --neg_chunk_size before use.
        self.add_argument('--batch_size', type=int, default=8,
                          help='batch size used for eval and test')
        self.add_argument('--neg_sample_size', type=int, default=-1,
                          help='negative sampling size for testing')
        self.add_argument('--neg_deg_sample', action='store_true',
                          help='negative sampling proportional to vertex degree for testing')
        self.add_argument('--neg_chunk_size', type=int, default=-1,
                          help='chunk size of the negative edges.')

        # Model hyper-parameters.
        self.add_argument('--hidden_dim', type=int, default=256,
                          help='hidden dim used by relation and entity')
        self.add_argument('-g', '--gamma', type=float, default=12.0,
                          help='margin value')
        self.add_argument('--eval_percent', type=float, default=1,
                          help='sample some percentage for evaluation.')
        self.add_argument('--no_eval_filter', action='store_true',
                          help='do not filter positive edges among negative edges for evaluation')

        # Devices and embedding layout.
        self.add_argument('--gpu', type=int, default=[-1], nargs='+',
                          help='a list of active gpu ids, e.g. 0')
        self.add_argument('--mix_cpu_gpu', action='store_true',
                          help='mix CPU and GPU training')
        self.add_argument('-de', '--double_ent', action='store_true',
                          help='double entitiy dim for complex number')
        self.add_argument('-dr', '--double_rel', action='store_true',
                          help='double relation dim for complex number')
        self.add_argument('--seed', type=int, default=0,
                          help='set random seed fro reproducibility')

        # Parallelism.
        self.add_argument('--num_worker', type=int, default=16,
                          help='number of workers used for loading data')
        self.add_argument('--num_proc', type=int, default=1,
                          help='number of process used')
        self.add_argument('--num_thread', type=int, default=1,
                          help='number of thread used')

    def parse_args(self, args=None, namespace=None):
        """Parse the command line and return the populated namespace.

        Parameters
        ----------
        args : list of str, optional
            Argument strings to parse. When omitted, argparse reads them
            from ``sys.argv``, which is what a no-argument call did before.
        namespace : object, optional
            Existing object to receive the parsed attributes. When omitted,
            argparse creates a new ``argparse.Namespace``.

        Returns
        -------
        The namespace holding one attribute per registered option.
        """
        # Forward both parameters so the parser can also be driven
        # programmatically with an explicit argument list.
        return super(ArgParser, self).parse_args(args, namespace)

def get_logger(args):
    """Configure file logging under ``args.model_path`` and return a logger.

    Log records are directed to ``eval.log`` inside ``args.model_path`` via
    ``logging.basicConfig`` (INFO level, file opened in write mode), and the
    log file location is printed to stdout.

    Parameters
    ----------
    args : object
        Must expose ``model_path``, the directory that will hold ``eval.log``.

    Returns
    -------
    logging.Logger
        The logger named after this module.

    Raises
    ------
    Exception
        If ``args.model_path`` does not exist.
    """
    model_dir = args.model_path
    if not os.path.exists(model_dir):
        raise Exception('No existing model_path: ' + model_dir)

    eval_log = os.path.join(model_dir, 'eval.log')
    logging_config = {
        'format': '%(asctime)s %(levelname)-8s %(message)s',
        'level': logging.INFO,
        'datefmt': '%Y-%m-%d %H:%M:%S',
        'filename': eval_log,
        'filemode': 'w',
    }
    logging.basicConfig(**logging_config)

    print("Logs are being recorded at: {}".format(eval_log))
    return logging.getLogger(__name__)

def _create_test_samplers(eval_dataset, args, num_workers, rank, ranks):
    """Create the 'chunk-head' and 'chunk-tail' test samplers for one rank.

    Returns a two-element list ``[head_sampler, tail_sampler]``, which is the
    form in which the samplers are handed to ``test`` / ``test_mp``.
    """
    samplers = []
    for mode in ('chunk-head', 'chunk-tail'):
        samplers.append(eval_dataset.create_sampler('test', args.batch_size,
                                                    args.neg_sample_size,
                                                    args.neg_chunk_size,
                                                    args.eval_filter,
                                                    mode=mode,
                                                    num_workers=num_workers,
                                                    rank=rank, ranks=ranks))
    return samplers


def _average_metrics(logs):
    """Average each metric over a list of metric dicts.

    The metric names are taken from the first entry. An empty list yields an
    empty dict instead of raising ``IndexError``.
    """
    if not logs:
        return {}
    return {metric: sum(log[metric] for log in logs) / len(logs)
            for metric in logs[0]}


def main(args):
    """Evaluate a saved model on the test split.

    Loads the dataset, builds the test samplers, restores the model from
    ``args.model_path`` and runs the test, either in this process or in
    ``args.num_proc`` worker processes whose metrics are averaged and printed.

    Parameters
    ----------
    args : argparse.Namespace
        Options as produced by ``ArgParser``. The namespace is modified in
        place: derived settings (``eval_filter``, ``batch_size_eval``,
        ``neg_sample_size_test``, ``step``, ...) are stored on it before it is
        passed on to the dataset, model-loading and test helpers.

    Raises
    ------
    AssertionError
        If ``--neg_deg_sample`` is given without ``--no_eval_filter``.
    Exception
        Propagated from ``get_logger`` when ``args.model_path`` is missing.
    """
    args.eval_filter = not args.no_eval_filter
    if args.neg_deg_sample:
        assert not args.eval_filter, "if negative sampling based on degree, we can't filter positive edges."

    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format)
    # Attributes set here are not read in this file; presumably the dataset,
    # model-loading and test helpers expect them on `args`.
    args.pickle_graph = False
    args.train = False
    args.valid = False
    args.test = True
    args.strict_rel_part = False
    args.soft_rel_part = False
    args.async_update = False
    args.batch_size_eval = args.batch_size

    logger = get_logger(args)
    # Here we want to use the regular negative sampler because we need to ensure that
    # all positive edges are excluded.
    eval_dataset = EvalDataset(dataset, args)

    args.neg_sample_size_test = args.neg_sample_size
    args.neg_deg_sample_eval = args.neg_deg_sample
    if args.neg_sample_size < 0:
        # A negative size means "use every node of the graph".
        args.neg_sample_size_test = args.neg_sample_size = eval_dataset.g.number_of_nodes()
    if args.neg_chunk_size < 0:
        args.neg_chunk_size = args.neg_sample_size

    multi_proc = args.num_proc > 1
    # for multiprocessing evaluation, we don't need to sample multiple batches at a time
    # in each process.
    num_workers = 1 if multi_proc else args.num_worker
    # The single-process case is simply rank 0 of 1.
    num_ranks = args.num_proc if multi_proc else 1
    test_samplers = [_create_test_samplers(eval_dataset, args, num_workers, rank, num_ranks)
                     for rank in range(num_ranks)]

    # load model
    n_entities = dataset.n_entities
    n_relations = dataset.n_relations
    ckpt_path = args.model_path
    model = load_model_from_checkpoint(logger, args, n_entities, n_relations, ckpt_path)

    if multi_proc:
        model.share_memory()
    # test
    args.step = 0
    args.max_step = 0
    start = time.time()
    if multi_proc:
        queue = mp.Queue(args.num_proc)
        procs = []
        for rank in range(args.num_proc):
            proc = mp.Process(target=test_mp, args=(args,
                                                    model,
                                                    test_samplers[rank],
                                                    rank,
                                                    'Test',
                                                    queue))
            procs.append(proc)
            proc.start()

        # Drain the queue before joining: a process that has put items on a
        # multiprocessing queue may not exit until they have been consumed.
        logs = []
        for _ in procs:
            logs.extend(queue.get())

        for name, value in _average_metrics(logs).items():
            print('Test average {} at [{}/{}]: {}'.format(name, args.step, args.max_step, value))

        for proc in procs:
            proc.join()
    else:
        test(args, model, test_samplers[0])
    print('Test takes {:.3f} seconds'.format(time.time() - start))

if __name__ == '__main__':
    # Script entry point: read the options from the command line and evaluate.
    parser = ArgParser()
    args = parser.parse_args()
    main(args)