# eval.py: evaluate a saved model checkpoint on the test split.
from dataloader import EvalDataset, TrainDataset
from dataloader import get_dataset

import argparse
import os
import logging
import time
import pickle

# Pick the training backend from the DGLBACKEND environment variable.
# Any value other than 'mxnet' (compared case-insensitively) selects PyTorch,
# which is also the default when the variable is unset.
backend = os.environ.get('DGLBACKEND', 'pytorch')
if backend.lower() != 'mxnet':
    import torch.multiprocessing as mp
    from train_pytorch import load_model_from_checkpoint
    from train_pytorch import test
else:
    import multiprocessing as mp
    from train_mxnet import load_model_from_checkpoint
    from train_mxnet import test

class ArgParser(argparse.ArgumentParser):
    """Command-line options for evaluating a saved model checkpoint."""

    def __init__(self):
        super(ArgParser, self).__init__()

        # --- model / data locations ---
        self.add_argument('--model_name', default='TransE',
                          choices=['TransE', 'TransE_l1', 'TransE_l2', 'TransH', 'TransR', 'TransD',
                                   'RESCAL', 'DistMult', 'ComplEx', 'RotatE', 'pRotatE'],
                          help='model to use')
        self.add_argument('--data_path', type=str, default='data',
                          help='root path of all dataset')
        self.add_argument('--dataset', type=str, default='FB15k',
                          help='dataset name, under data_path')
        self.add_argument('--format', type=str, default='1',
                          help='the format of the dataset.')
        self.add_argument('--model_path', type=str, default='ckpts',
                          help='the place where models are saved')

        # --- evaluation sampling ---
        self.add_argument('--batch_size', type=int, default=8,
                          help='batch size used for eval and test')
        # -1: main() replaces a negative value with the graph's node count.
        self.add_argument('--neg_sample_size', type=int, default=-1,
                          help='negative sampling size for testing')
        self.add_argument('--neg_deg_sample', action='store_true',
                          help='negative sampling proportional to vertex degree for testing')
        # -1: main() falls back to the (resolved) neg_sample_size.
        self.add_argument('--neg_chunk_size', type=int, default=-1,
                          help='chunk size of the negative edges.')

        # --- model hyper-parameters (must match the saved checkpoint) ---
        self.add_argument('--hidden_dim', type=int, default=256,
                          help='hidden dim used by relation and entity')
        self.add_argument('-g', '--gamma', type=float, default=12.0,
                          help='margin value')
        self.add_argument('--eval_percent', type=float, default=1,
                          help='sample some percentage for evaluation.')
        self.add_argument('--no_eval_filter', action='store_true',
                          help='do not filter positive edges among negative edges for evaluation')

        # --- devices / reproducibility ---
        self.add_argument('--gpu', type=int, default=-1,
                          help='use GPU')
        self.add_argument('--mix_cpu_gpu', action='store_true',
                          help='mix CPU and GPU training')
        self.add_argument('-de', '--double_ent', action='store_true',
                          help='double entity dim for complex number')
        self.add_argument('-dr', '--double_rel', action='store_true',
                          help='double relation dim for complex number')
        self.add_argument('--seed', type=int, default=0,
                          help='set random seed for reproducibility')

        # --- parallelism ---
        self.add_argument('--num_worker', type=int, default=16,
                          help='number of workers used for loading data')
        self.add_argument('--num_proc', type=int, default=1,
                          help='number of process used')

    def parse_args(self, args=None, namespace=None):
        """Parse command-line options.

        Mirrors ``argparse.ArgumentParser.parse_args`` so callers (and tests)
        can pass an explicit argument list; calling it with no arguments still
        parses ``sys.argv[1:]`` exactly as before.

        Args:
            args: list of argument strings, or ``None`` for ``sys.argv[1:]``.
            namespace: optional object to populate instead of a new
                ``argparse.Namespace``.

        Returns:
            The populated namespace.
        """
        return super().parse_args(args, namespace)

def get_logger(args):
    """Configure logging to ``<args.model_path>/eval.log`` and return a logger.

    Args:
        args: parsed options; only ``args.model_path`` is read.

    Returns:
        The ``logging.Logger`` for this module.

    Raises:
        FileNotFoundError: if ``args.model_path`` is not an existing directory.
    """
    # The log file is written inside model_path, so it must be a directory.
    # isdir (rather than exists) reports a stray regular file here instead of
    # failing later when the log file is opened inside it.
    if not os.path.isdir(args.model_path):
        raise FileNotFoundError('No existing model_path: ' + args.model_path)

    log_file = os.path.join(args.model_path, 'eval.log')

    # NOTE: basicConfig does nothing if the root logger already has handlers,
    # in which case records are not written to log_file.
    logging.basicConfig(
        format='%(asctime)s %(levelname)-8s %(message)s',
        level=logging.INFO,
        datefmt='%Y-%m-%d %H:%M:%S',
        filename=log_file,
        filemode='w'  # truncate any eval.log left by a previous run
    )

    logger = logging.getLogger(__name__)
    print("Logs are being recorded at: {}".format(log_file))
    return logger

def _create_test_samplers(eval_dataset, args, num_workers, rank, ranks):
    """Return ``[head_sampler, tail_sampler]`` for evaluation process ``rank``.

    Both samplers share every setting except the sampling mode; the head
    sampler is created first, matching the original construction order.
    """
    samplers = []
    for mode in ('chunk-head', 'chunk-tail'):
        samplers.append(eval_dataset.create_sampler('test', args.batch_size,
                                                    args.neg_sample_size,
                                                    args.neg_chunk_size,
                                                    args.eval_filter,
                                                    mode=mode,
                                                    num_workers=num_workers,
                                                    rank=rank, ranks=ranks))
    return samplers


def main(args):
    """Evaluate a saved checkpoint on the test split.

    Loads the dataset and the checkpoint under ``args.model_path``, builds
    head/tail test samplers (one pair per process when ``args.num_proc > 1``),
    runs ``test`` and prints the elapsed wall-clock time. With several
    processes, the metric dicts they put on a shared queue are averaged and
    printed here.

    Args:
        args: namespace produced by ``ArgParser``. It is mutated in place with
            derived settings (eval_filter, test flags, resolved negative
            sample/chunk sizes, step counters) read by the downstream code.

    Raises:
        ValueError: if ``--neg_deg_sample`` is used while positive-edge
            filtering is enabled (i.e. without ``--no_eval_filter``).
    """
    args.eval_filter = not args.no_eval_filter
    if args.neg_deg_sample and args.eval_filter:
        # Raised explicitly (not asserted) so the check survives `python -O`.
        raise ValueError("if negative sampling based on degree, we can't filter positive edges.")

    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format)
    args.pickle_graph = False
    args.train = False
    args.valid = False
    args.test = True
    args.batch_size_eval = args.batch_size

    logger = get_logger(args)
    # Here we want to use the regular negative sampler because we need to
    # ensure that all positive edges are excluded.
    eval_dataset = EvalDataset(dataset, args)

    args.neg_sample_size_test = args.neg_sample_size
    args.neg_deg_sample_eval = args.neg_deg_sample
    if args.neg_sample_size < 0:
        # A negative size is replaced by the total number of graph nodes.
        args.neg_sample_size_test = args.neg_sample_size = eval_dataset.g.number_of_nodes()
    if args.neg_chunk_size < 0:
        args.neg_chunk_size = args.neg_sample_size

    multi_proc = args.num_proc > 1
    # For multiprocessing evaluation, we don't need to sample multiple batches
    # at a time in each process.
    num_workers = 1 if multi_proc else args.num_worker
    if multi_proc:
        samplers = [_create_test_samplers(eval_dataset, args, num_workers, rank, args.num_proc)
                    for rank in range(args.num_proc)]
    else:
        samplers = [_create_test_samplers(eval_dataset, args, num_workers, 0, 1)]

    # load model
    n_entities = dataset.n_entities
    n_relations = dataset.n_relations
    ckpt_path = args.model_path
    model = load_model_from_checkpoint(logger, args, n_entities, n_relations, ckpt_path)

    if multi_proc:
        # presumably so worker processes share the parameters instead of
        # copying them — confirm against the backend's model class
        model.share_memory()

    # test
    args.step = 0
    args.max_step = 0
    start = time.time()
    if multi_proc:
        queue = mp.Queue(args.num_proc)
        procs = []
        for rank in range(args.num_proc):
            proc = mp.Process(target=test,
                              args=(args, model, samplers[rank], 'Test', queue))
            procs.append(proc)
            proc.start()

        # Read every result before joining: a process that has put items on a
        # multiprocessing queue may not terminate until they are consumed, so
        # joining first can deadlock.
        total_metrics = {}
        for _ in range(args.num_proc):
            metrics = queue.get()
            for k, v in metrics.items():
                total_metrics[k] = total_metrics.get(k, 0) + v / args.num_proc
        for proc in procs:
            proc.join()

        # Report the average over all processes, not the last process's values.
        for k, v in total_metrics.items():
            print('Test average {} at [{}/{}]: {}'.format(k, args.step, args.max_step, v))
    else:
        test(args, model, samplers[0])
    print('Test takes {:.3f} seconds'.format(time.time() - start))


if __name__ == '__main__':
    # Script entry point: parse the command line and run the evaluation.
    main(ArgParser().parse_args())