#! -*- coding: utf-8 -*-
# Conditional language model with BERT
# Randomly generate text per class; in this demo the classes are sentiment polarity (positive/negative)
# Reference: https://kexue.fm/archives/7124
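#
# How it works: a class embedding conditions BERT's layer norms
# (layer_norm_cond, i.e. conditional layer norm); BERT is trained as a
# unidirectional LM (application='lm') on labelled sentences, and sampling
# with the desired label then yields class-conditioned text.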

from bert4torch.models import build_transformer_model, BaseModel
from bert4torch.tokenizers import Tokenizer, load_vocab
from bert4torch.snippets import sequence_padding, text_segmentate, Callback, AutoRegressiveDecoder, ListDataset
import torch
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn as nn


# Model configuration
maxlen = 128
batch_size = 16
num_classes = 2
epochs = 20

# BERT checkpoint configuration
config_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/pytorch_model.bin'
dict_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/vocab.txt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load and simplify the vocabulary, and build the tokenizer
token_dict, keep_tokens = load_vocab(
    dict_path=dict_path,
    simplified=True,
    startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]'],
)
tokenizer = Tokenizer(token_dict, do_lower_case=True)
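# keep_tokens holds the indices of the retained tokens in the original vocab;
# it is passed to the model below to slice the pretrained embeddings to match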


# Load the dataset
class MyDataset(ListDataset):
    @staticmethod
    def load_data(filenames):
        """加载数据,并尽量划分为不超过maxlen的句子
        """
        D = []
        seps, strips = u'\n。!?!?;;,, ', u';;,, '
        for filename in filenames:
            with open(filename, encoding='utf-8') as f:
                for l in f:
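                    # Each line is "text<TAB>label"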
                    text, label = l.strip().split('\t')
                    for t in text_segmentate(text, maxlen - 2, seps, strips):
                        D.append((t, int(label)))
                    # if len(D) >= 100:
                    #     break
        return D

def collate_fn(batch):
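    # Tokenize each text and pad the batch to a uniform length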
    batch_token_ids, batch_segment_ids, batch_labels = [], [], []
    for text, label in batch:
        token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
        batch_token_ids.append(token_ids)
        batch_segment_ids.append(segment_ids)
        batch_labels.append(label)

    batch_token_ids = torch.tensor(sequence_padding(batch_token_ids), dtype=torch.long, device=device)
    batch_segment_ids = torch.tensor(sequence_padding(batch_segment_ids), dtype=torch.long, device=device)
    batch_labels = torch.tensor(batch_labels, dtype=torch.long, device=device)
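    # Model inputs: (token_ids, segment_ids, label); target: the token_ids themselves (LM objective)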
    return [batch_token_ids, batch_segment_ids, batch_labels], batch_token_ids

# Build the dataloader (all three splits are pooled for this demo)
train_dataloader = DataLoader(MyDataset([
    'F:/Projects/data/corpus/sentence_classification/sentiment/sentiment.train.data',
    'F:/Projects/data/corpus/sentence_classification/sentiment/sentiment.valid.data',
    'F:/Projects/data/corpus/sentence_classification/sentiment/sentiment.test.data']), 
    batch_size=batch_size, shuffle=True, collate_fn=collate_fn) 

# Define the model structure on top of BERT
class Model(BaseModel):
    def __init__(self) -> None:
        super().__init__()
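        # 128-dim class embedding that conditions BERT's layer norms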
        c = nn.Embedding(num_classes, 128)
        self.bert = build_transformer_model(config_path,
                                            checkpoint_path,
                                            with_mlm=True,
                                            application='lm',
                                            keep_tokens=keep_tokens,  # keep only the tokens in keep_tokens, simplifying the original vocab
                                            layer_norm_cond=c,
                                            ignore_invalid_weights=True)  # ignore weights that have no match in the checkpoint

    def forward(self, inputs):
        _, seq_output = self.bert(inputs)  # [btz, seq_len, vocab_size]
        return seq_output

model = Model().to(device)

class CrossEntropyLoss(nn.CrossEntropyLoss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, input, target):
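        # Shift by one position: logits at step t predict the token at step t+1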
        input = input[:, :-1, :].reshape(-1, input.shape[-1])
        target = target[:, 1:].flatten()
        return super().forward(input, target)

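# ignore_index=0 masks [PAD] positions out of the loss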
model.compile(loss=CrossEntropyLoss(ignore_index=0), optimizer=optim.Adam(model.parameters(), 1e-5))


class RandomSentiment(AutoRegressiveDecoder):
    """根据情感标签(0:负,1:正)随机生成一批句子
    """
    @AutoRegressiveDecoder.wraps(default_rtype='logits')
    def predict(self, inputs, output_ids, states):
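        # The prefix generated so far is the token sequence; inputs[0] carries the
        # class label, which conditions the layer norms. Only the last position's
        # logits are needed to sample the next token.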
        token_ids = output_ids
        segment_ids = torch.zeros_like(token_ids, device=device)
        label = inputs[0]
        return model.predict([token_ids, segment_ids, label])[:, -1, :]

    def generate(self, label, n=1, topp=0.95):
        results = self.random_sample([[label]], n, topp=topp)  # random sampling (top-p / nucleus)
        return [tokenizer.decode(ids.cpu().numpy()) for ids in results]


random_sentiment = RandomSentiment(
    start_id=tokenizer._token_start_id,
    end_id=tokenizer._token_end_id,
    maxlen=maxlen,
    device=device
)
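# Example (assuming a trained model): random_sentiment.generate(1, n=3)
# samples three sentences under the positive label.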


def just_show():
    print(u'Positive samples:')
    print(random_sentiment.generate(1, 5, 0.95), '\n')
    print(u'Negative samples:')
    print(random_sentiment.generate(0, 5, 0.95), '\n')


class Evaluator(Callback):
    """评估与保存
    """
    def __init__(self):
        super().__init__()
        self.lowest = 1e10

    def on_epoch_end(self, steps, epoch, logs=None):
        # Save the best model (lowest training loss)
        if logs['loss'] <= self.lowest:
            self.lowest = logs['loss']
            # model.save_weights('./best_model.pt')
        # Show sample generations
        just_show()


if __name__ == '__main__':

    evaluator = Evaluator()

    model.fit(train_dataloader, epochs=epochs, steps_per_epoch=None, callbacks=[evaluator])
else:

    model.load_weights('./best_model.pt')

"""
Positive samples:
[
    u'外观时尚、漂亮、性价比高。',
    u'外观漂亮,配置均衡,比较满意,性价比高,外观漂亮,性能较高。',
    u'我是在大学的时候看到这本书的,所以一直在买。书中的作者是林静蕾,她用自己的口吻写出了一个孩子成长中的心路历程,让我看到了她们成长中的不同之处,以及她们成长过程中的不同境界。让我很欣赏!',
    u'我想这是一本能够告诉读者什么是坏的,而不是教你怎样说话,告诉我什么是错。这里我推荐了《我要讲故事》,这本书是我很喜欢的一本书,我认为它的理由很多,但是,我相信我。如果你从中得到一些改进,或者你已经有了一个明智的决定。',
    u'我们一家五口住的是标间,大床房,大床的床很舒服;而我们在携程网上订了两套大床房,这个酒店的价格还是比较合理的;但是房间的隔音效果不太理想,有点响的声音;酒店门口的地铁在施工中,不方便;但是酒店的门口的出租车不知道是哪个车的,打车不是很方便;酒店外面的停'
]

Negative samples:
[
    u'不知道是不是因为电池不太好,不是我不喜欢。',
    u'看了评论才买的. 结果发现不是那么便宜, 价格也不便宜.',
    u'1、外壳不容易沾手印,不容易洗洗2、屏幕有点旧, 不能下载铃声',
    u'我是7月6日订购了《杜拉拉升职记》并已通过银行付款,为什么订单下了两周多至今还未到货?是收货时间太快了,可能就这么过去了吧?',
    u'这本书我是在网上先看了一遍,后来我再看了一遍。感觉作者的文笔实在太烂了,特别是在写他的博客时特别别扭,写得很不专业,特别是他写股票时那个情绪调节的小男孩,简直就是自作聪明的样子,简直就是自作聪明的一种表现!'
]
"""