"""Find the top-n nearest candidate words for a target word in context.

Encodes a sentence with a pretrained model, takes the contextual embedding at
the target word's position, then re-encodes the sentence with each candidate
word substituted at that position and ranks candidates by cosine similarity.

(NOTE: the original file header here was code-hosting viewer chrome —
filename, size, and gutter line numbers — not part of the source.)
"""
import sys
import os
import torch
import argparse
import numpy as np

tencentpretrain_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(tencentpretrain_dir)

from tencentpretrain.embeddings import *
from tencentpretrain.encoders import *
from tencentpretrain.utils.constants import *
from tencentpretrain.utils import *
from tencentpretrain.utils.config import load_hyperparam
from tencentpretrain.utils.vocab import Vocab
from tencentpretrain.opts import model_opts, tokenizer_opts


class SequenceEncoder(torch.nn.Module):
    """Embedding + encoder stack producing contextual token representations.

    Built from the registries in ``tencentpretrain``: each embedding named in
    ``args.embedding`` is instantiated and merged into a composite
    ``Embedding``, followed by the encoder selected by ``args.encoder``.
    """

    def __init__(self, args):
        super(SequenceEncoder, self).__init__()
        self.embedding = Embedding(args)
        vocab_size = len(args.tokenizer.vocab)
        for embedding_name in args.embedding:
            sub_embedding = str2embedding[embedding_name](args, vocab_size)
            self.embedding.update(sub_embedding, embedding_name)
        self.encoder = str2encoder[args.encoder](args)

    def forward(self, src, seg):
        """Return encoder output for token ids ``src`` with segment ids ``seg``."""
        return self.encoder(self.embedding(src, seg), seg)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    model_opts(parser)

    parser.add_argument("--load_model_path", default=None, type=str,
                        help="Path of the input model.")
    parser.add_argument("--cand_vocab_path", default=None, type=str,
                        help="Path of the candidate vocabulary file.")
    parser.add_argument("--test_path", type=str, required=True,
                        help="Path of the target word and its context.")
    parser.add_argument("--config_path", default="models/bert/base_config.json", type=str,
                        help="Path of the config file.")

    tokenizer_opts(parser)

    parser.add_argument("--batch_size", type=int, default=64,
                        help="Batch size.")
    parser.add_argument("--seq_length", type=int, default=128,
                        help="Sequence length.")

    parser.add_argument("--topn", type=int, default=15)

    args = parser.parse_args()
    args = load_hyperparam(args)

    args.spm_model_path = None

    # Full model vocabulary (for token ids) and the candidate set to rank.
    vocab = Vocab()
    vocab.load(args.vocab_path)

    cand_vocab = Vocab()
    cand_vocab.load(args.cand_vocab_path)

    args.tokenizer = str2tokenizer[args.tokenizer](args)

    model = SequenceEncoder(args)

    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    pretrained_model = torch.load(args.load_model_path, map_location="cpu")
    model.load_state_dict(pretrained_model, strict=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(torch.cuda.device_count()))
        model = torch.nn.DataParallel(model)
    model = model.to(device)
    model.eval()

    PAD_ID = args.tokenizer.vocab.get(PAD_TOKEN)
    with open(args.test_path, mode="r", encoding="utf-8") as f:
        # Each input line: "<target_word>\t<context sentence>".
        for line in f:
            line = line.strip().split("\t")
            if len(line) != 2:
                continue
            target_word, context = line[0], line[1]
            print("Original sentence: " + context)
            print("Target word: " + target_word)
            src = args.tokenizer.convert_tokens_to_ids(args.tokenizer.tokenize(context))
            seg = [1] * len(src)
            if len(src) > args.seq_length:
                src = src[:args.seq_length]
                seg = seg[:args.seq_length]
            while len(src) < args.seq_length:
                src.append(PAD_ID)
                # Padding positions get segment id 0 so the encoder's attention
                # mask excludes them (the original appended PAD_ID, which only
                # happens to be 0 with the stock vocab).
                seg.append(0)

            target_word_id = vocab.get(target_word)
            if target_word_id in src:
                position = src.index(target_word_id)
            else:
                print("The target word is not in the sentence.")
                continue

            # Contextual embedding of the target word in the original sentence.
            with torch.no_grad():
                output = model(torch.LongTensor([src]).to(device),
                               torch.LongTensor([seg]).to(device))
            output = output.cpu().data.numpy()
            output = output.reshape([args.seq_length, -1])
            target_embedding = output[position, :].reshape(1, -1).astype("float")

            def encode_batch(word_ids):
                """Encode the sentence with each candidate id substituted at
                ``position``; return the embeddings at that position."""
                src_batch = torch.LongTensor([src] * len(word_ids))
                src_batch[:, position] = torch.LongTensor(word_ids)
                seg_batch = torch.LongTensor([seg] * len(word_ids))
                with torch.no_grad():
                    out = model(src_batch.to(device), seg_batch.to(device))
                out = out.cpu().data.numpy()
                out = np.reshape(out, (len(word_ids), args.seq_length, -1))
                return out[:, position, :].tolist()

            # Keep candidate words in a list parallel to cand_embeddings so
            # indices stay aligned even when an OOV candidate is skipped
            # (the original crashed on candidates missing from the model vocab).
            cand_words, cand_embeddings, batch_ids = [], [], []
            for word in cand_vocab.i2w:
                word_id = vocab.w2i.get(word)
                if word_id is None:
                    continue
                cand_words.append(word)
                batch_ids.append(word_id)
                if len(batch_ids) == args.batch_size:
                    cand_embeddings.extend(encode_batch(batch_ids))
                    batch_ids = []
            # Flush the final partial batch (the original's in-loop last-index
            # test could emit an empty batch when the last word was skipped).
            if batch_ids:
                cand_embeddings.extend(encode_batch(batch_ids))

            if not cand_embeddings:
                continue

            sims = torch.nn.functional.cosine_similarity(torch.FloatTensor(target_embedding),
                                                         torch.FloatTensor(cand_embeddings))

            # Skip index 0: the best match is assumed to be the target word
            # itself when it appears in the candidate vocabulary.
            sorted_ids = torch.argsort(sims, descending=True)
            for j in sorted_ids[1: args.topn + 1]:
                print(cand_words[j].strip() + "\t" + str(sims[j].item()))