# eval.py: load a saved model checkpoint and evaluate it on the test split.
from dataloader import EvalDataset, TrainDataset
from dataloader import get_dataset

import argparse
import os
import logging
import time
import pickle

10
11
from utils import get_compatible_batch_size

# (repository blame annotation from the web view: committed by Da Zheng)
12
# Pick the tensor backend from the DGLBACKEND environment variable (PyTorch
# unless it is set to 'mxnet', case-insensitively) and import the matching
# multiprocessing module plus the checkpoint-loading and evaluation helpers.
backend = os.environ.get('DGLBACKEND', 'pytorch')
if backend.lower() == 'mxnet':
    import multiprocessing as mp

    from train_mxnet import load_model_from_checkpoint
    from train_mxnet import test
    # NOTE(review): no `test_mp` is imported for MXNet, so the multi-process
    # branch of main() (num_proc > 1) would hit a NameError on this backend
    # unless test_mp is provided some other way.
else:
    import torch.multiprocessing as mp

    from train_pytorch import load_model_from_checkpoint
    from train_pytorch import test, test_mp

class ArgParser(argparse.ArgumentParser):
    """Command-line options for evaluating a saved model checkpoint.

    Groups: model choice and hyper-parameters, dataset location/format,
    evaluation batching and negative sampling, and device/process settings.
    """

    def __init__(self):
        super(ArgParser, self).__init__()

        # --- model ---
        self.add_argument('--model_name', default='TransE',
                          choices=['TransE', 'TransE_l1', 'TransE_l2', 'TransR',
                                   'RESCAL', 'DistMult', 'ComplEx', 'RotatE'],
                          help='model to use')

        # --- dataset ---
        self.add_argument('--data_path', type=str, default='data',
                          help='root path of all dataset')
        self.add_argument('--dataset', type=str, default='FB15k',
                          help='dataset name, under data_path')
        self.add_argument('--format', type=str, default='built_in',
                          help='the format of the dataset, it can be built_in, '
                               'raw_udd_{htr} and udd_{htr}')
        self.add_argument('--data_files', type=str, default=None, nargs='+',
                          help='a list of data files, e.g. entity relation train valid test')
        self.add_argument('--model_path', type=str, default='ckpts',
                          help='the place where models are saved')

        # --- evaluation / negative sampling ---
        self.add_argument('--batch_size_eval', type=int, default=8,
                          help='batch size used for eval and test')
        self.add_argument('--neg_sample_size_eval', type=int, default=-1,
                          help='negative sampling size for testing')
        self.add_argument('--neg_deg_sample_eval', action='store_true',
                          help='negative sampling proportional to vertex degree for testing')

        # --- model hyper-parameters ---
        self.add_argument('--hidden_dim', type=int, default=256,
                          help='hidden dim used by relation and entity')
        self.add_argument('-g', '--gamma', type=float, default=12.0,
                          help='margin value')
        self.add_argument('--eval_percent', type=float, default=1,
                          help='sample some percentage for evaluation.')
        self.add_argument('--no_eval_filter', action='store_true',
                          help='do not filter positive edges among negative edges for evaluation')

        # --- devices and processes ---
        self.add_argument('--gpu', type=int, default=[-1], nargs='+',
                          help='a list of active gpu ids, e.g. 0')
        self.add_argument('--mix_cpu_gpu', action='store_true',
                          help='mix CPU and GPU training')
        self.add_argument('-de', '--double_ent', action='store_true',
                          help='double entity dim for complex number')
        self.add_argument('-dr', '--double_rel', action='store_true',
                          help='double relation dim for complex number')
        self.add_argument('--num_proc', type=int, default=1,
                          help='number of process used')
        self.add_argument('--num_thread', type=int, default=1,
                          help='number of thread used')

    def parse_args(self, args=None, namespace=None):
        """Parse ``args`` (``sys.argv[1:]`` when None) into a namespace.

        The optional parameters mirror ``argparse.ArgumentParser.parse_args``
        so an explicit argument list can be supplied; calling with no
        arguments behaves exactly as before.
        """
        return super().parse_args(args, namespace)

def get_logger(args):
    """Configure file logging under ``args.model_path`` and return a logger.

    Records go to ``<model_path>/eval.log``, which is truncated on every run
    (``filemode='w'``). Configuration goes through ``logging.basicConfig``,
    which does nothing if the root logger already has handlers.

    Args:
        args: namespace providing ``model_path``, the directory holding the
            saved model.

    Returns:
        logging.Logger: the logger named after this module.

    Raises:
        FileNotFoundError: if ``args.model_path`` does not exist.
    """
    if not os.path.exists(args.model_path):
        # FileNotFoundError subclasses Exception, so existing handlers still match.
        raise FileNotFoundError('No existing model_path: ' + args.model_path)

    log_file = os.path.join(args.model_path, 'eval.log')

    logging.basicConfig(
        format='%(asctime)s %(levelname)-8s %(message)s',
        level=logging.INFO,
        datefmt='%Y-%m-%d %H:%M:%S',
        filename=log_file,
        filemode='w'
    )

    logger = logging.getLogger(__name__)
    print("Logs are being recorded at: {}".format(log_file))
    return logger

def _create_test_samplers(eval_dataset, args):
    """Build one head-mode and one tail-mode test sampler per evaluation rank.

    The rank count is ``args.num_proc`` when it is greater than 1, otherwise 1
    (a single sampler pair with ``rank=0, ranks=1``).

    Returns:
        tuple[list, list]: ``(heads, tails)``. Entry ``i`` of each list is the
        sampler created for rank ``i`` with ``mode='chunk-head'`` and
        ``mode='chunk-tail'`` respectively.
    """
    ranks = args.num_proc if args.num_proc > 1 else 1

    def make(mode, rank):
        # The eval negative-sample size is passed twice, as in the original
        # call sites (the second positional slot presumably being the chunk
        # size — TODO confirm against EvalDataset.create_sampler).
        return eval_dataset.create_sampler('test', args.batch_size_eval,
                                           args.neg_sample_size_eval,
                                           args.neg_sample_size_eval,
                                           args.eval_filter,
                                           mode=mode,
                                           num_workers=args.num_workers,
                                           rank=rank, ranks=ranks)

    heads, tails = [], []
    for rank in range(ranks):
        heads.append(make('chunk-head', rank))
        tails.append(make('chunk-tail', rank))
    return heads, tails


def _average_metrics(logs):
    """Average every metric over a list of metric dicts.

    The keys of the first dict select which metrics are averaged. An empty
    list yields an empty dict instead of raising IndexError.
    """
    if not logs:
        return {}
    return {metric: sum(log[metric] for log in logs) / len(logs)
            for metric in logs[0]}


def main(args):
    """Evaluate a saved model checkpoint on the test split of the dataset.

    With ``args.num_proc <= 1`` evaluation runs in this process via ``test``.
    Otherwise ``args.num_proc`` worker processes run ``test_mp``; their
    results are collected from a shared queue and the averaged metrics are
    printed.

    Raises:
        ValueError: if degree-based negative sampling is requested while
            filtering of positive edges is enabled.
    """
    args.eval_filter = not args.no_eval_filter
    if args.neg_deg_sample_eval and args.eval_filter:
        # An explicit raise (instead of `assert`) survives `python -O`.
        raise ValueError("if negative sampling based on degree, we can't filter positive edges.")

    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format, args.data_files)
    # Evaluation-only configuration: only the test split is enabled and the
    # remaining options are switched off.
    args.pickle_graph = False
    args.train = False
    args.valid = False
    args.test = True
    args.strict_rel_part = False
    args.soft_rel_part = False
    args.async_update = False

    logger = get_logger(args)
    # Here we want to use the regular negative sampler because we need to ensure that
    # all positive edges are excluded.
    eval_dataset = EvalDataset(dataset, args)

    if args.neg_sample_size_eval < 0:
        # A negative size is replaced by the node count of the evaluation graph.
        args.neg_sample_size_eval = args.neg_sample_size = eval_dataset.g.number_of_nodes()
    args.batch_size_eval = get_compatible_batch_size(args.batch_size_eval, args.neg_sample_size_eval)

    args.num_workers = 8  # fix num_workers to 8
    test_sampler_heads, test_sampler_tails = _create_test_samplers(eval_dataset, args)

    # load model
    model = load_model_from_checkpoint(logger, args, dataset.n_entities,
                                       dataset.n_relations, args.model_path)
    if args.num_proc > 1:
        model.share_memory()

    # test
    args.step = 0
    args.max_step = 0
    start = time.time()
    if args.num_proc > 1:
        # NOTE(review): test_mp is only imported for the PyTorch backend.
        queue = mp.Queue(args.num_proc)
        procs = []
        for i in range(args.num_proc):
            proc = mp.Process(target=test_mp, args=(args,
                                                    model,
                                                    [test_sampler_heads[i], test_sampler_tails[i]],
                                                    i,
                                                    'Test',
                                                    queue))
            procs.append(proc)
            proc.start()

        # Consume results before join(): joining a process that still has
        # buffered items on a queue can deadlock. Each worker is assumed to
        # put one list of metric dicts.
        logs = []
        for _ in range(args.num_proc):
            logs.extend(queue.get())

        for k, v in _average_metrics(logs).items():
            print('Test average {} at [{}/{}]: {}'.format(k, args.step, args.max_step, v))

        for proc in procs:
            proc.join()
    else:
        test(args, model, [test_sampler_heads[0], test_sampler_tails[0]])
    print('Test takes {:.3f} seconds'.format(time.time() - start))

if __name__ == '__main__':
    # Script entry point: parse the command line, then run the evaluation.
    main(ArgParser().parse_args())