train_mxnet.py 3.93 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# -*- coding: utf-8 -*-
#
# setup.py
#
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from models import KEModel

import mxnet as mx
from mxnet import gluon
from mxnet import ndarray as nd

import os
import logging
import time
import json

def load_model(logger, args, n_entities, n_relations, ckpt=None):
    """Build a fresh KEModel from the given hyper-parameters.

    Args:
        logger: logger used to report progress.
        args: argparse namespace carrying the model hyper-parameters
            (model_name, hidden_dim, gamma, double_ent, double_rel).
        n_entities: number of entities in the knowledge graph.
        n_relations: number of relation types in the knowledge graph.
        ckpt: optional checkpoint object; restoring from one is not
            supported here and raises NotImplementedError.

    Returns:
        A newly initialized KEModel.

    Raises:
        NotImplementedError: if ``ckpt`` is not None.
    """
    model = KEModel(args, args.model_name, n_entities, n_relations,
                    args.hidden_dim, args.gamma,
                    double_entity_emb=args.double_ent, double_relation_emb=args.double_rel)
    if ckpt is not None:
        # `assert False` is silently stripped under `python -O`; raise an
        # explicit exception so the unsupported path always fails loudly.
        raise NotImplementedError("We do not support loading model emb for general Embedding")

    logger.info('Load model {}'.format(args.model_name))
    return model

def load_model_from_checkpoint(logger, args, n_entities, n_relations, ckpt_path):
    """Create a KEModel and restore its embeddings from ``ckpt_path``."""
    restored = load_model(logger, args, n_entities, n_relations)
    restored.load_emb(ckpt_path, args.dataset)
    return restored

46
47
def train(args, model, train_sampler, valid_samplers=None, rank=0, rel_parts=None, barrier=None):
    """Run the single-process MXNet training loop for ``args.max_step`` steps.

    Args:
        args: argparse namespace with training configuration; ``args.step``
            is mutated each iteration so the model can see the current step.
        model: KEModel to train.
        train_sampler: iterator yielding (positive, negative) graph pairs.
        valid_samplers: optional samplers used for periodic validation.
        rank: process rank (only rank 0 is meaningful — single process).
        rel_parts: relation partitions written back when strict_rel_part.
        barrier: unused in the single-process MXNet path.
    """
    assert args.num_proc <= 1, "MXNet KGE does not support multi-process now"
    assert args.rel_part == False, "No need for relation partition in single process for MXNet KGE"

    # Dump the full configuration to the log for reproducibility.
    for name in vars(args):
        logging.info('{:20}:{}'.format(name, getattr(args, name)))

    # Pick the training device; -1 means CPU.
    gpu_id = -1
    if len(args.gpu) > 0:
        if args.mix_cpu_gpu and args.num_proc > 1:
            gpu_id = args.gpu[rank % len(args.gpu)]
        else:
            gpu_id = args.gpu[0]

    if args.strict_rel_part:
        model.prepare_relation(mx.gpu(gpu_id))

    step_logs = []
    tic = time.time()
    for step in range(0, args.max_step):
        pos_g, neg_g = next(train_sampler)
        args.step = step  # expose the current step to the model via args
        with mx.autograd.record():
            loss, log = model.forward(pos_g, neg_g, gpu_id)
        loss.backward()
        step_logs.append(log)
        model.update(gpu_id)

        if step % args.log_interval == 0:
            # Report metrics averaged over the steps since the last report.
            for key in step_logs[0].keys():
                mean = sum(entry[key] for entry in step_logs) / len(step_logs)
                print('[Train]({}/{}) average {}: {}'.format(step, args.max_step, key, mean))
            step_logs = []
            print(time.time() - tic)
            tic = time.time()

        run_valid = (args.valid and step % args.eval_interval == 0
                     and step > 1 and valid_samplers is not None)
        if run_valid:
            tic = time.time()
            test(args, model, valid_samplers, mode='Valid')
            print('test:', time.time() - tic)

    if args.strict_rel_part:
        model.writeback_relation(rank, rel_parts)

    # clear cache
    step_logs = []

90
def test(args, model, test_samplers, rank=0, mode='Test', queue=None):
    """Evaluate the model on each sampler and print averaged metrics.

    Args:
        args: argparse namespace with evaluation configuration.
        model: trained KEModel to evaluate.
        test_samplers: list of evaluation samplers; each yields
            (positive, negative) graph pairs. Samplers are reset in place
            at the end so they can be reused for the next evaluation.
        rank: process rank (only rank 0 is meaningful — single process).
        mode: label printed with the metrics (e.g. 'Test' or 'Valid').
        queue: unused in the single-process MXNet path.
    """
    assert args.num_proc <= 1, "MXNet KGE does not support multi-process now"

    logs = []

    # Pick the evaluation device; -1 means CPU.
    if len(args.gpu) > 0:
        gpu_id = args.gpu[rank % len(args.gpu)] if args.mix_cpu_gpu and args.num_proc > 1 else args.gpu[0]
    else:
        gpu_id = -1

    if args.strict_rel_part:
        model.load_relation(mx.gpu(gpu_id))

    for sampler in test_samplers:
        for pos_g, neg_g in sampler:
            # forward_test appends per-sample metric dicts to `logs`.
            model.forward_test(pos_g, neg_g, logs, gpu_id)

    # Average each metric over every evaluated sample.
    metrics = {}
    if len(logs) > 0:
        for metric in logs[0].keys():
            metrics[metric] = sum([log[metric] for log in logs]) / len(logs)

    for k, v in metrics.items():
        print('{} average {}: {}'.format(mode, k, v))

    # Reset samplers in place for reuse in later evaluations.
    for i in range(len(test_samplers)):
        test_samplers[i] = test_samplers[i].reset()