# -*- coding: utf-8 -*-
#
# eval.py
#
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
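
# Evaluate a trained knowledge graph embedding model: load a checkpoint from
# --model_path and report ranking metrics on the test set.
#
# Example invocation (a sketch; it assumes a TransE_l2 checkpoint was saved
# under ckpts/ by the companion training script):
#
#   DGLBACKEND=pytorch python eval.py --model_name TransE_l2 --dataset FB15k \
#       --model_path ckpts --batch_size_eval 16 --gpu 0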

from dataloader import EvalDataset, TrainDataset
from dataloader import get_dataset

import argparse
import os
import logging
import time
import pickle

from utils import get_compatible_batch_size

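# Pick the multiprocessing module and model implementation matching the DGL
# backend (defaults to pytorch). Note that multi-process evaluation (test_mp)
# is only imported for the PyTorch backend.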
backend = os.environ.get('DGLBACKEND', 'pytorch')
if backend.lower() == 'mxnet':
    import multiprocessing as mp
    from train_mxnet import load_model_from_checkpoint
    from train_mxnet import test
else:
    import torch.multiprocessing as mp
    from train_pytorch import load_model_from_checkpoint
    from train_pytorch import test, test_mp

class ArgParser(argparse.ArgumentParser):
    def __init__(self):
        super(ArgParser, self).__init__()

        self.add_argument('--model_name', default='TransE',
                          choices=['TransE', 'TransE_l1', 'TransE_l2', 'TransR',
                                   'RESCAL', 'DistMult', 'ComplEx', 'RotatE'],
                          help='model to use')
        self.add_argument('--data_path', type=str, default='data',
                          help='root path of all dataset')
        self.add_argument('--dataset', type=str, default='FB15k',
                          help='dataset name, under data_path')
        self.add_argument('--format', type=str, default='built_in',
                          help='the format of the dataset, it can be built_in, '\
                               'raw_udd_{htr} or udd_{htr}')
        self.add_argument('--data_files', type=str, default=None, nargs='+',
                          help='a list of data files, e.g. entity relation train valid test')
        self.add_argument('--model_path', type=str, default='ckpts',
                          help='the place where models are saved')
        self.add_argument('--batch_size_eval', type=int, default=8,
                          help='batch size used for eval and test')
        self.add_argument('--neg_sample_size_eval', type=int, default=-1,
                          help='negative sampling size for testing')
        self.add_argument('--neg_deg_sample_eval', action='store_true',
                          help='negative sampling proportional to vertex degree for testing')
        self.add_argument('--hidden_dim', type=int, default=256,
                          help='hidden dim used by relation and entity')
        self.add_argument('-g', '--gamma', type=float, default=12.0,
                          help='margin value')
        self.add_argument('--eval_percent', type=float, default=1,
                          help='sample this percentage of the data for evaluation')
        self.add_argument('--no_eval_filter', action='store_true',
                          help='do not filter positive edges among negative edges for evaluation')
        self.add_argument('--gpu', type=int, default=[-1], nargs='+',
                          help='a list of active gpu ids, e.g. 0')
        self.add_argument('--mix_cpu_gpu', action='store_true',
                          help='mix CPU and GPU training')
        self.add_argument('-de', '--double_ent', action='store_true',
                          help='double entity dim for complex number')
        self.add_argument('-dr', '--double_rel', action='store_true',
                          help='double relation dim for complex number')
        self.add_argument('--num_proc', type=int, default=1,
                          help='number of processes used')
        self.add_argument('--num_thread', type=int, default=1,
                          help='number of threads used')

    def parse_args(self):
        args = super().parse_args()
        return args

def get_logger(args):
    if not os.path.exists(args.model_path):
        raise Exception('No existing model_path: ' + args.model_path)

    log_file = os.path.join(args.model_path, 'eval.log')

    logging.basicConfig(
        format='%(asctime)s %(levelname)-8s %(message)s',
        level=logging.INFO,
        datefmt='%Y-%m-%d %H:%M:%S',
        filename=log_file,
        filemode='w'
    )

    logger = logging.getLogger(__name__)
    print("Logs are being recorded at: {}".format(log_file))
    return logger


def main(args):
    args.eval_filter = not args.no_eval_filter
    if args.neg_deg_sample_eval:
        assert not args.eval_filter, "if negative sampling based on degree, we can't filter positive edges."

    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format, args.data_files)
    args.pickle_graph = False
    args.train = False
    args.valid = False
    args.test = True
    args.strict_rel_part = False
    args.soft_rel_part = False
    args.async_update = False

    logger = get_logger(args)
    # Here we want to use the regular negative sampler because we need to ensure that
    # all positive edges are excluded.
    eval_dataset = EvalDataset(dataset, args)

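    # A negative sample size below 0 means ranking against every entity in the graph.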
    if args.neg_sample_size_eval < 0:
        args.neg_sample_size_eval = args.neg_sample_size = eval_dataset.g.number_of_nodes()
    args.batch_size_eval = get_compatible_batch_size(args.batch_size_eval, args.neg_sample_size_eval)

    args.num_workers = 8 # fix num_workers to 8
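    # Build head- and tail-corrupting test samplers. With multiple processes,
    # each rank gets its own pair of samplers over its share of the test set.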
    if args.num_proc > 1:
        test_sampler_tails = []
        test_sampler_heads = []
        for i in range(args.num_proc):
            test_sampler_head = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                            args.neg_sample_size_eval,
                                                            args.neg_sample_size_eval,
                                                            args.eval_filter,
                                                            mode='chunk-head',
                                                            num_workers=args.num_workers,
                                                            rank=i, ranks=args.num_proc)
            test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                            args.neg_sample_size_eval,
                                                            args.neg_sample_size_eval,
                                                            args.eval_filter,
                                                            mode='chunk-tail',
                                                            num_workers=args.num_workers,
                                                            rank=i, ranks=args.num_proc)
            test_sampler_heads.append(test_sampler_head)
            test_sampler_tails.append(test_sampler_tail)
    else:
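        # Single process: one pair of samplers over the whole test set.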
        test_sampler_head = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                        args.neg_sample_size_eval,
                                                        args.neg_sample_size_eval,
                                                        args.eval_filter,
                                                        mode='chunk-head',
                                                        num_workers=args.num_workers,
                                                        rank=0, ranks=1)
        test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size_eval,
                                                        args.neg_sample_size_eval,
                                                        args.neg_sample_size_eval,
                                                        args.eval_filter,
                                                        mode='chunk-tail',
                                                        num_workers=args.num_workers,
                                                        rank=0, ranks=1)

    # load model
    n_entities = dataset.n_entities
    n_relations = dataset.n_relations
    ckpt_path = args.model_path
    model = load_model_from_checkpoint(logger, args, n_entities, n_relations, ckpt_path)

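    # Share model parameters across worker processes before multi-process testing.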
    if args.num_proc > 1:
        model.share_memory()
    # test
    args.step = 0
    args.max_step = 0
    start = time.time()
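    # Spawn one evaluation process per sampler pair; each process pushes its
    # result logs into a shared queue for aggregation below.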
    if args.num_proc > 1:
        queue = mp.Queue(args.num_proc)
        procs = []
        for i in range(args.num_proc):
            proc = mp.Process(target=test_mp, args=(args,
                                                    model,
                                                    [test_sampler_heads[i], test_sampler_tails[i]],
                                                    i,
                                                    'Test',
                                                    queue))
            procs.append(proc)
            proc.start()

        metrics = {}
        logs = []
        for i in range(args.num_proc):
            log = queue.get()
            logs = logs + log

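        # Average each metric over the logs gathered from all worker processes.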
        for metric in logs[0].keys():
            metrics[metric] = sum([log[metric] for log in logs]) / len(logs)
        for k, v in metrics.items():
            print('Test average {} at [{}/{}]: {}'.format(k, args.step, args.max_step, v))

        for proc in procs:
            proc.join()
    else:
        test(args, model, [test_sampler_head, test_sampler_tail])
    print('Test takes {:.3f} seconds'.format(time.time() - start))

if __name__ == '__main__':
    args = ArgParser().parse_args()
    main(args)