#! -*- coding: utf-8 -*-
# PromptBERT-style unsupervised sentence embedding
# Official project: https://github.com/kongds/Prompt-BERT
# Reference project: https://github.com/Macielyoung/sentence_representation_matching
# |     solution    |   ATEC  |  BQ  |  LCQMC  |  PAWSX  |  STS-B  |
# |    PromptBert   |  33.98  | 49.89|  73.18  |  13.30  |  73.42  |

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
from bert4torch.tokenizers import Tokenizer, load_vocab
from bert4torch.models import build_transformer_model, BaseModel
from bert4torch.snippets import ListDataset, sequence_padding, Callback
from torch.utils.data import DataLoader
from scipy.stats import pearsonr, spearmanr
import numpy as np
import sys
import jieba
jieba.initialize()


# ============================= Basic parameters =============================
model_type, task_name, dropout_rate = sys.argv[1:]  # command-line arguments
# model_type, task_name, dropout_rate = 'BERT', 'ATEC', 0.3  # for debugging
print(model_type, task_name, dropout_rate)
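# Example invocation (illustrative; the argument values must come from the sets asserted below):
#   python task_sentence_embedding_unsup_PromptBert.py BERT STS-B 0.3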

assert model_type in {'BERT', 'RoBERTa', 'NEZHA', 'RoFormer', 'SimBERT'}
assert task_name in {'ATEC', 'BQ', 'LCQMC', 'PAWSX', 'STS-B'}
if model_type in {'BERT', 'RoBERTa', 'SimBERT'}:
    model_name = 'bert'
elif model_type in {'RoFormer'}:
    model_name = 'roformer'
elif model_type in {'NEZHA'}:
    model_name = 'nezha'

dropout_rate = float(dropout_rate)
batch_size = 32
template_len = 15

if task_name == 'PAWSX':
    maxlen = 128 + template_len
else:
    maxlen = 64 + template_len

# BERT configuration
model_dir = {
    'BERT': 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12',
    'RoBERTa': 'F:/Projects/pretrain_ckpt/robert/[hit_torch_base]--chinese-roberta-wwm-ext-base',
    'NEZHA': 'F:/Projects/pretrain_ckpt/nezha/[huawei_noah_torch_base]--nezha-cn-base',
    'RoFormer': 'F:/Projects/pretrain_ckpt/roformer/[sushen_torch_base]--roformer_v1_base',
    'SimBERT': 'F:/Projects/pretrain_ckpt/simbert/[sushen_torch_base]--simbert_chinese_base',
}[model_type]

config_path = f'{model_dir}/bert_config.json' if model_type == 'BERT' else f'{model_dir}/config.json'
checkpoint_path = f'{model_dir}/pytorch_model.bin'
dict_path = f'{model_dir}/vocab.txt'
data_path = 'F:/Projects/data/corpus/sentence_embedding/'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ============================= Load datasets =============================
# Build the tokenizer
if model_type in ['RoFormer']:
    tokenizer = Tokenizer(dict_path, do_lower_case=True, pre_tokenize=lambda s: jieba.lcut(s, HMM=False), add_special_tokens='[X]')
else:
    tokenizer = Tokenizer(dict_path, do_lower_case=True, add_special_tokens='[X]')

replace_token = "[X]"  # placeholder token, later replaced by the sentence itself (or repeated to build the template-only input)
mask_token = "[MASK]"
prompt_templates = ['"{}" 的意思为[MASK]'.format(replace_token), '"{}"这句话的意思是[MASK]'.format(replace_token)]
tao = 0.05  # contrastive temperature (not used directly below; the model applies scale=20.0, i.e. 1/tao)

# Extend the vocabulary with the placeholder token [X]; compound_tokens is passed to
# build_transformer_model below so the embedding matrix gains a row for the new token.
token_dict = load_vocab(dict_path)
compound_tokens = [[len(token_dict)]]
token_dict['[X]'] = len(token_dict)

# Load the datasets
def load_data(filenames):
    D = []
    for filename in filenames:
        with open(filename, 'r', encoding='utf-8') as f:
            for line in tqdm(f.readlines(), desc='Load data'):
                cache = line.split('\t')
                text1, text2, label = cache[0][:maxlen-template_len], cache[1][:maxlen-template_len], cache[-1]
                for text in [text1, text2]:
                    sentence_pair = []
                    for template in prompt_templates:
                        sent_num = len(tokenizer.tokenize(text))
                        prompt_sent = template.replace(replace_token, text)
                        template_sent = template.replace(replace_token, replace_token * sent_num)
                        sentence_pair.extend([prompt_sent, template_sent])
                    D.append((sentence_pair, int(label)))
    return D
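# Illustration only (hypothetical input sentence): for every text, and for each of the two
# templates, load_data builds a prompted sentence plus a template-only sentence whose number
# of [X] placeholders equals len(tokenizer.tokenize(text)), e.g.
#   prompt_sent   : '"今天天气不错" 的意思为[MASK]'
#   template_sent : '"[X][X]...[X]" 的意思为[MASK]'
# so each entry of D is ([prompt0, template0, prompt1, template1], label).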

all_names = [f'{data_path}{task_name}/{task_name}.{f}.data' for f in ['train', 'valid', 'test']]
print(all_names)
train_texts = load_data(all_names)
valid_texts = list(zip(train_texts[::2], train_texts[1::2]))

if task_name != 'PAWSX':
    np.random.shuffle(train_texts)
    train_texts = train_texts[:10000]

# Build the training dataloader
def collate_fn(batch):
    batch_tensor = [[] for _ in range(4)]
    for prompt_data, _ in batch:
        for i, item in enumerate(prompt_data):
            batch_tensor[i].append(tokenizer.encode(item, maxlen=maxlen)[0])
    for i, item in enumerate(batch_tensor):
        batch_tensor[i] = torch.tensor(sequence_padding(item, maxlen), dtype=torch.long, device=device)
    labels = torch.arange(batch_tensor[0].size(0), device=device)
    return batch_tensor, labels
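# Each training batch thus consists of four id tensors, in order
#   [prompt0_ids, template0_ids, prompt1_ids, template1_ids],
# with labels = arange(batch_size): sample i under template 0 is the positive of
# sample i under template 1 (an in-batch contrastive setup).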

train_dataloader = DataLoader(ListDataset(data=train_texts), batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

# Build the evaluation dataloader
def collate_fn_test(batch):
    text1_ids, text2_ids, labels = [], [], []
    for text1, text2 in batch:
        label = text1[-1]
        text1, text2 = text1[0][0], text2[0][0]
        text1_ids.append(tokenizer.encode(text1, maxlen=maxlen)[0])
        text2_ids.append(tokenizer.encode(text2, maxlen=maxlen)[0])
        labels.append(label)
    text1_ids = torch.tensor(sequence_padding(text1_ids), dtype=torch.long, device=device)
    text2_ids = torch.tensor(sequence_padding(text2_ids), dtype=torch.long, device=device)
    labels = torch.tensor(labels, dtype=torch.long, device=device)
    return [text1_ids, text2_ids], labels

valid_dataloader = DataLoader(ListDataset(data=valid_texts), batch_size=batch_size, collate_fn=collate_fn_test) 

# ============================= Define the model =============================
class PromptBert(BaseModel):
    def __init__(self, scale=20.0):
        super().__init__()
        self.bert = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path, model=model_name, 
                                            dropout_rate=dropout_rate, segment_vocab_size=0, compound_tokens=compound_tokens)
        self.scale = scale

    def forward(self, prompt0_input, template0_input, prompt1_input, template1_input):
        embeddings_a = self.get_sentence_embedding(prompt0_input, template0_input)
        embeddings_b = self.get_sentence_embedding(prompt1_input, template1_input)
        scores = self.cos_sim(embeddings_a, embeddings_b) * self.scale  # [btz, btz]
        return scores

    def get_sentence_embedding(self, prompt_input_ids, template_input_ids):
        prompt_mask_embedding = self.get_mask_embedding(prompt_input_ids)
        template_mask_embedding = self.get_mask_embedding(template_input_ids)
        # To cancel the influence of the prompt template on the loss, the sentence
        # representation is the [MASK] embedding from the prompted sentence minus the
        # [MASK] embedding from the template-only input.
        sentence_embedding = prompt_mask_embedding - template_mask_embedding
        return sentence_embedding

    def get_mask_embedding(self, input_ids):
        last_hidden_state = self.bert([input_ids])
        # Pick out the hidden state at the [MASK] position (each prompted input contains a single [MASK])
        mask_index = (input_ids == tokenizer._token_mask_id).long()
        input_mask_expanded = mask_index.unsqueeze(-1).expand(last_hidden_state.size()).float()
        mask_embedding = torch.sum(last_hidden_state * input_mask_expanded, 1)
        return mask_embedding
    
    def predict(self, input_ids):
        self.eval()
        with torch.no_grad():
            mask_embedding = self.get_mask_embedding(input_ids)
        return mask_embedding
    
    @staticmethod
    def cos_sim(a, b):
        a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
        b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
        return torch.mm(a_norm, b_norm.transpose(0, 1))

model = PromptBert().to(device)
       
# Define the loss and optimizer (custom choices are supported)
model.compile(
    loss=nn.CrossEntropyLoss(),
    optimizer=optim.Adam(model.parameters(), lr=2e-5),
)
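# Note on the objective: forward() returns a [btz, btz] matrix of scaled cosine similarities
# between the two prompted views, and the labels are arange(btz), so CrossEntropyLoss here
# acts as an in-batch contrastive (InfoNCE-style) loss with temperature 1/scale = 0.05.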

class Evaluator(Callback):
    """评估与保存
    """
    def __init__(self):
        self.best_val_sim = 0.

    def on_epoch_end(self, global_step, epoch, logs=None):
        val_sim = self.evaluate(valid_dataloader)
        if val_sim > self.best_val_sim:
            self.best_val_sim = val_sim
            # model.save_weights('best_model.pt')
        print(f'val_sim: {val_sim:.5f}, best_val_sim: {self.best_val_sim:.5f}\n')
    
    @staticmethod
    def evaluate(data):
        embeddings1, embeddings2, labels = [], [], []
        for (text1_ids, text2_ids), label in data:
            embeddings1.append(model.predict(text1_ids))
            embeddings2.append(model.predict(text2_ids))
            labels.append(label)

        embeddings1 = torch.cat(embeddings1)
        embeddings2 = torch.cat(embeddings2)
        labels = torch.cat(labels)

        sims = F.cosine_similarity(embeddings1, embeddings2).cpu().numpy()
        labels = labels.cpu().numpy()
        return spearmanr(sims, labels)[0]

if __name__ == "__main__":
    evaluator = Evaluator()
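    # Train for 5 epochs; after each epoch the Evaluator prints the Spearman correlation
    # between the cosine similarity of the two sentence embeddings and the gold labels.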
    model.fit(train_dataloader, epochs=5, steps_per_epoch=None, callbacks=[evaluator])