"router/vscode:/vscode.git/clone" did not exist on "8aece3bd68e26e4fec520c276785d8c391882787"
eval.py 9.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
from dataloader import EvalDataset, TrainDataset
from dataloader import get_dataset

import argparse
import os
import logging
import time
import pickle

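# Pick the backend-specific implementation: DGLBACKEND defaults to 'pytorch';
# set DGLBACKEND=mxnet to evaluate a model trained with the MXNet trainer.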
backend = os.environ.get('DGLBACKEND', 'pytorch')
if backend.lower() == 'mxnet':
    import multiprocessing as mp
    from train_mxnet import load_model_from_checkpoint
    from train_mxnet import test
else:
    import torch.multiprocessing as mp
    from train_pytorch import load_model_from_checkpoint
    from train_pytorch import test, test_mp

class ArgParser(argparse.ArgumentParser):
    def __init__(self):
        super(ArgParser, self).__init__()

        self.add_argument('--model_name', default='TransE',
                          choices=['TransE', 'TransE_l1', 'TransE_l2', 'TransH', 'TransR', 'TransD',
                                   'RESCAL', 'DistMult', 'ComplEx', 'RotatE', 'pRotatE'],
                          help='model to use')
        self.add_argument('--data_path', type=str, default='data',
                          help='root path of all dataset')
        self.add_argument('--dataset', type=str, default='FB15k',
                          help='dataset name, under data_path')
        self.add_argument('--format', type=str, default='1',
                          help='the format of the dataset.')
        self.add_argument('--model_path', type=str, default='ckpts',
                          help='the place where models are saved')

        self.add_argument('--batch_size', type=int, default=8,
                          help='batch size used for eval and test')
        self.add_argument('--neg_sample_size', type=int, default=-1,
                          help='negative sampling size for testing')
        self.add_argument('--neg_deg_sample', action='store_true',
                          help='negative sampling proportional to vertex degree for testing')
        self.add_argument('--neg_chunk_size', type=int, default=-1,
                          help='chunk size of the negative edges.')
        self.add_argument('--hidden_dim', type=int, default=256,
                          help='hidden dim used by relation and entity')
        self.add_argument('-g', '--gamma', type=float, default=12.0,
                          help='margin value')
        self.add_argument('--eval_percent', type=float, default=1,
                          help='sample some percentage for evaluation.')
        self.add_argument('--no_eval_filter', action='store_true',
                          help='do not filter positive edges among negative edges for evaluation')

        self.add_argument('--gpu', type=int, default=[-1], nargs='+',
                          help='a list of active gpu ids, e.g. 0')
        self.add_argument('--mix_cpu_gpu', action='store_true',
                          help='mix CPU and GPU training')
        self.add_argument('-de', '--double_ent', action='store_true',
                          help='double entity dim for complex number')
        self.add_argument('-dr', '--double_rel', action='store_true',
                          help='double relation dim for complex number')
        self.add_argument('--seed', type=int, default=0,
                          help='set random seed for reproducibility')

        self.add_argument('--num_worker', type=int, default=16,
                          help='number of workers used for loading data')
        self.add_argument('--num_proc', type=int, default=1,
                          help='number of process used')

    def parse_args(self):
        args = super().parse_args()
        return args

def get_logger(args):
    if not os.path.exists(args.model_path):
        raise Exception('No existing model_path: ' + args.model_path)

    log_file = os.path.join(args.model_path, 'eval.log')

    logging.basicConfig(
        format='%(asctime)s %(levelname)-8s %(message)s',
        level=logging.INFO,
        datefmt='%Y-%m-%d %H:%M:%S',
        filename=log_file,
        filemode='w'
    )

    logger = logging.getLogger(__name__)
    print("Logs are being recorded at: {}".format(log_file))
    return logger

def main(args):
    args.eval_filter = not args.no_eval_filter
    if args.neg_deg_sample:
        assert not args.eval_filter, "if negative sampling is based on degree, we can't filter positive edges."

    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format)
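    # Evaluation only uses the test split; the remaining flags switch off
    # training-only features (graph pickling, relation partitioning, async
    # updates) that the dataset and sampler code otherwise expect.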
    args.pickle_graph = False
    args.train = False
    args.valid = False
    args.test = True
    args.strict_rel_part = False
    args.soft_rel_part = False
    args.async_update = False
    args.batch_size_eval = args.batch_size

    logger = get_logger(args)
    # Here we want to use the regular negative sampler because we need to ensure that
    # all positive edges are excluded.
    eval_dataset = EvalDataset(dataset, args)

    args.neg_sample_size_test = args.neg_sample_size
    args.neg_deg_sample_eval = args.neg_deg_sample
    if args.neg_sample_size < 0:
        args.neg_sample_size_test = args.neg_sample_size = eval_dataset.g.number_of_nodes()
    if args.neg_chunk_size < 0:
        args.neg_chunk_size = args.neg_sample_size
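    # A negative --neg_sample_size means ranking each test triple against every
    # entity in the graph; by default the negative chunk size matches it.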

    num_workers = args.num_worker
    # for multiprocessing evaluation, we don't need to sample multiple batches at a time
    # in each process.
    if args.num_proc > 1:
        num_workers = 1
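    # Each process gets a pair of samplers: 'chunk-head' corrupts head entities
    # and 'chunk-tail' corrupts tail entities; rank/ranks split the test edges
    # across processes.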
    if args.num_proc > 1:
        test_sampler_tails = []
        test_sampler_heads = []
        for i in range(args.num_proc):
            test_sampler_head = eval_dataset.create_sampler('test', args.batch_size,
                                                            args.neg_sample_size,
                                                            args.neg_chunk_size,
                                                            args.eval_filter,
                                                            mode='chunk-head',
                                                            num_workers=num_workers,
                                                            rank=i, ranks=args.num_proc)
            test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size,
                                                            args.neg_sample_size,
                                                            args.neg_chunk_size,
                                                            args.eval_filter,
                                                            mode='chunk-tail',
                                                            num_workers=num_workers,
                                                            rank=i, ranks=args.num_proc)
            test_sampler_heads.append(test_sampler_head)
            test_sampler_tails.append(test_sampler_tail)
    else:
        test_sampler_head = eval_dataset.create_sampler('test', args.batch_size,
                                                        args.neg_sample_size,
                                                        args.neg_chunk_size,
                                                        args.eval_filter,
                                                        mode='chunk-head',
                                                        num_workers=num_workers,
                                                        rank=0, ranks=1)
        test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size,
                                                        args.neg_sample_size,
                                                        args.neg_chunk_size,
                                                        args.eval_filter,
                                                        mode='chunk-tail',
                                                        num_workers=num_workers,
                                                        rank=0, ranks=1)

    # load model
    n_entities = dataset.n_entities
    n_relations = dataset.n_relations
    ckpt_path = args.model_path
    model = load_model_from_checkpoint(logger, args, n_entities, n_relations, ckpt_path)

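    # With multiple evaluation processes, keep the model parameters in shared
    # memory so every worker reads the same embeddings without copying them.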
    if args.num_proc > 1:
        model.share_memory()
    # test
    args.step = 0
    args.max_step = 0
    start = time.time()
    if args.num_proc > 1:
        queue = mp.Queue(args.num_proc)
        procs = []
        for i in range(args.num_proc):
            proc = mp.Process(target=test_mp, args=(args,
                                                    model,
                                                    [test_sampler_heads[i], test_sampler_tails[i]],
                                                    i,
                                                    'Test',
                                                    queue))
            procs.append(proc)
            proc.start()

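        # Gather the per-triple logs from every worker and average each metric
        # (e.g. MR, MRR, HITS@k) over all test triples.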
        total_metrics = {}
        metrics = {}
        logs = []
        for i in range(args.num_proc):
            log = queue.get()
            logs = logs + log

        for metric in logs[0].keys():
            metrics[metric] = sum([log[metric] for log in logs]) / len(logs)
        for k, v in metrics.items():
            print('Test average {} at [{}/{}]: {}'.format(k, args.step, args.max_step, v))

        for proc in procs:
            proc.join()
    else:
        test(args, model, [test_sampler_head, test_sampler_tail])
    print('Test takes {:.3f} seconds'.format(time.time() - start))

if __name__ == '__main__':
    args = ArgParser().parse_args()
    main(args)
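
# Example invocation (hypothetical values: the hyper-parameters and --model_path
# must match the checkpoint saved at training time):
#
#   DGLBACKEND=pytorch python eval.py --model_name TransE_l2 --dataset FB15k \
#       --hidden_dim 400 --gamma 19.9 --batch_size 16 --gpu 0 \
#       --model_path ckpts/TransE_l2_FB15k_0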