# eval.py -- evaluate a saved entity/relation embedding model checkpoint
# (TransE, DistMult, ComplEx, RotatE, ...) on the test split of a dataset.
from dataloader import EvalDataset, TrainDataset
from dataloader import get_dataset

import argparse
import os
import logging
import time
import pickle

# Select the backend from the DGLBACKEND environment variable (PyTorch unless
# it is set to 'mxnet', case-insensitively) and import the matching
# multiprocessing module, checkpoint loader and test routines.
backend = os.environ.get('DGLBACKEND', 'pytorch')
if backend.lower() == 'mxnet':
    import multiprocessing as mp

    from train_mxnet import load_model_from_checkpoint
    from train_mxnet import test
    # NOTE: test_mp is not imported for MXNet, so multi-process evaluation
    # (--num_proc > 1) is only available with the PyTorch backend.
else:
    import torch.multiprocessing as mp

    from train_pytorch import load_model_from_checkpoint
    from train_pytorch import test, test_mp

class ArgParser(argparse.ArgumentParser):
    """Command-line options for the evaluation script.

    Groups of options:
      * model / data location: ``--model_name``, ``--data_path``, ``--dataset``,
        ``--format``, ``--model_path``;
      * evaluation sampling: ``--batch_size``, ``--neg_sample_size``,
        ``--neg_deg_sample``, ``--neg_chunk_size``, ``--eval_percent``,
        ``--no_eval_filter``;
      * model hyper-parameters: ``--hidden_dim``, ``--gamma``, ``-de``, ``-dr``;
      * execution: ``--gpu``, ``--mix_cpu_gpu``, ``--seed``, ``--num_worker``,
        ``--num_proc``.
    """

    def __init__(self):
        super(ArgParser, self).__init__()

        self.add_argument('--model_name', default='TransE',
                          choices=['TransE', 'TransE_l1', 'TransE_l2', 'TransH', 'TransR', 'TransD',
                                   'RESCAL', 'DistMult', 'ComplEx', 'RotatE', 'pRotatE'],
                          help='model to use')
        self.add_argument('--data_path', type=str, default='data',
                          help='root path of all dataset')
        self.add_argument('--dataset', type=str, default='FB15k',
                          help='dataset name, under data_path')
        self.add_argument('--format', type=str, default='1',
                          help='the format of the dataset.')
        self.add_argument('--model_path', type=str, default='ckpts',
                          help='the place where models are saved')

        self.add_argument('--batch_size', type=int, default=8,
                          help='batch size used for eval and test')
        # A negative value is replaced in main() by the number of graph nodes.
        self.add_argument('--neg_sample_size', type=int, default=-1,
                          help='negative sampling size for testing')
        self.add_argument('--neg_deg_sample', action='store_true',
                          help='negative sampling proportional to vertex degree for testing')
        # A negative value is replaced in main() by the negative sample size.
        self.add_argument('--neg_chunk_size', type=int, default=-1,
                          help='chunk size of the negative edges.')
        self.add_argument('--hidden_dim', type=int, default=256,
                          help='hidden dim used by relation and entity')
        self.add_argument('-g', '--gamma', type=float, default=12.0,
                          help='margin value')
        self.add_argument('--eval_percent', type=float, default=1,
                          help='sample some percentage for evaluation.')
        self.add_argument('--no_eval_filter', action='store_true',
                          help='do not filter positive edges among negative edges for evaluation')
        self.add_argument('--gpu', type=int, default=[-1], nargs='+',
                          help='a list of active gpu ids, e.g. 0')
        self.add_argument('--mix_cpu_gpu', action='store_true',
                          help='mix CPU and GPU training')
        self.add_argument('-de', '--double_ent', action='store_true',
                          help='double entity dim for complex number')
        self.add_argument('-dr', '--double_rel', action='store_true',
                          help='double relation dim for complex number')
        self.add_argument('--seed', type=int, default=0,
                          help='set random seed for reproducibility')

        self.add_argument('--num_worker', type=int, default=16,
                          help='number of workers used for loading data')
        self.add_argument('--num_proc', type=int, default=1,
                          help='number of process used')

    def parse_args(self, args=None, namespace=None):
        """Parse ``args`` (``sys.argv[1:]`` when ``None``) into a namespace.

        The optional parameters mirror ``argparse.ArgumentParser.parse_args``
        so the parser can also be driven programmatically; calling it with no
        arguments behaves exactly as before.
        """
        return super().parse_args(args, namespace)

def get_logger(args):
    """Route log records to ``eval.log`` inside ``args.model_path`` and return a logger.

    Args:
        args: parsed options; only ``args.model_path`` is read.

    Returns:
        The ``logging.Logger`` named after this module.

    Raises:
        Exception: when ``args.model_path`` does not exist on disk.

    The root logger is configured through ``logging.basicConfig`` at INFO level
    with ``filemode='w'``, so an existing ``eval.log`` is truncated. Per the
    logging docs, ``basicConfig`` does nothing if the root logger already has
    handlers.
    """
    model_dir = args.model_path
    if not os.path.exists(model_dir):
        raise Exception('No existing model_path: ' + model_dir)

    log_file = os.path.join(model_dir, 'eval.log')
    log_config = dict(
        format='%(asctime)s %(levelname)-8s %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        filename=log_file,
        filemode='w',
    )
    logging.basicConfig(**log_config)

    module_logger = logging.getLogger(__name__)
    print("Logs are being recorded at: {}".format(log_file))
    return module_logger

def _create_test_sampler_pair(eval_dataset, args, num_workers, rank, ranks):
    """Create the chunked head and tail test samplers for ``rank`` out of ``ranks``.

    Returns:
        ``(head_sampler, tail_sampler)``, created in that order.
    """
    def make_sampler(mode):
        return eval_dataset.create_sampler('test', args.batch_size,
                                           args.neg_sample_size,
                                           args.neg_chunk_size,
                                           args.eval_filter,
                                           mode=mode,
                                           num_workers=num_workers,
                                           rank=rank, ranks=ranks)

    return make_sampler('chunk-head'), make_sampler('chunk-tail')


def _average_metrics(logs):
    """Average every metric over a list of metric dicts.

    Metric names are taken from the first dict; an empty ``logs`` yields ``{}``
    instead of raising.
    """
    if not logs:
        return {}
    return {metric: sum(log[metric] for log in logs) / len(logs)
            for metric in logs[0]}


def _run_multiprocess_test(args, model, sampler_pairs):
    """Run ``test_mp`` in one process per sampler pair and print averaged metrics.

    Each process gets its own ``[head_sampler, tail_sampler]`` list and is
    expected to put a list of metric dicts on the shared queue (assumed from how
    the results are consumed below -- confirm against ``test_mp``).
    """
    num_proc = len(sampler_pairs)
    queue = mp.Queue(num_proc)
    procs = []
    for rank, samplers in enumerate(sampler_pairs):
        proc = mp.Process(target=test_mp,
                          args=(args, model, list(samplers), rank, 'Test', queue))
        procs.append(proc)
        proc.start()

    # Collect all results before joining: per the multiprocessing docs, a child
    # that has put data on a queue may not terminate until that data is consumed.
    logs = []
    for _ in range(num_proc):
        logs.extend(queue.get())

    metrics = _average_metrics(logs)
    if not metrics:
        print('Test produced no metrics.')
    for k, v in metrics.items():
        print('Test average {} at [{}/{}]: {}'.format(k, args.step, args.max_step, v))

    for proc in procs:
        proc.join()


def main(args):
    """Evaluate the checkpoint stored in ``args.model_path`` on the test split.

    Args:
        args: namespace produced by ``ArgParser``. This function also sets
            derived fields on it (``eval_filter``, ``test``, ``batch_size_eval``,
            ``neg_sample_size_test``, ``step``, ...), presumably consumed by the
            dataset / model / test code.

    Results are printed to stdout; log records go to ``<model_path>/eval.log``.

    Raises:
        ValueError: if degree-based negative sampling is combined with
            positive-edge filtering, or if ``--num_proc > 1`` is requested with
            the MXNet backend (``test_mp`` is only imported for PyTorch).
        Exception: if ``args.model_path`` does not exist (from ``get_logger``).
    """
    args.eval_filter = not args.no_eval_filter
    if args.neg_deg_sample and args.eval_filter:
        raise ValueError("if negative sampling based on degree, we can't filter positive edges.")
    if args.num_proc > 1 and backend.lower() == 'mxnet':
        # Fail fast instead of a NameError on test_mp after all data is loaded.
        raise ValueError('--num_proc > 1 requires the PyTorch backend; '
                         'multi-process evaluation is not available with MXNet.')

    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format)
    args.pickle_graph = False
    args.train = False
    args.valid = False
    args.test = True
    args.strict_rel_part = False
    args.async_update = False
    args.batch_size_eval = args.batch_size

    logger = get_logger(args)
    # Here we want to use the regular negative sampler because we need to ensure that
    # all positive edges are excluded.
    eval_dataset = EvalDataset(dataset, args)

    args.neg_sample_size_test = args.neg_sample_size
    args.neg_deg_sample_eval = args.neg_deg_sample
    if args.neg_sample_size < 0:
        # A negative size means: use every node of the graph.
        args.neg_sample_size_test = args.neg_sample_size = eval_dataset.g.number_of_nodes()
    if args.neg_chunk_size < 0:
        args.neg_chunk_size = args.neg_sample_size

    # One sampler pair per process; without multiprocessing a single pair
    # (rank 0 of 1) is used.
    ranks = args.num_proc if args.num_proc > 1 else 1
    # For multiprocessing evaluation, we don't need to sample multiple batches
    # at a time in each process.
    num_workers = 1 if ranks > 1 else args.num_worker
    sampler_pairs = [_create_test_sampler_pair(eval_dataset, args, num_workers, rank, ranks)
                     for rank in range(ranks)]

    # load model
    model = load_model_from_checkpoint(logger, args, dataset.n_entities,
                                       dataset.n_relations, args.model_path)
    if ranks > 1:
        model.share_memory()

    # test
    args.step = 0
    args.max_step = 0
    start = time.time()
    if ranks > 1:
        _run_multiprocess_test(args, model, sampler_pairs)
    else:
        test(args, model, list(sampler_pairs[0]))
    print('Test takes {:.3f} seconds'.format(time.time() - start))

if __name__ == '__main__':
    # Script entry point: parse the command line and run the evaluation.
    main(ArgParser().parse_args())